# Web-mirror page residue: "mirror of" (upstream not named in this copy),
# synced 2025-03-18 00:53:32 +00:00; 412 lines, 15 KiB.
from __future__ import unicode_literals, print_function
import copy
import re
import sys
from abc import abstractmethod
from functools import wraps
from operator import attrgetter
from traceback import print_exc
from typing import Text
from clearml_agent.helper.console import ListFormatter, print_text
from clearml_agent.helper.dicts import filter_keys
import six
from clearml_agent.backend_api import services
from clearml_agent.errors import APIError, CommandFailedError
from clearml_agent.helper.base import Singleton, return_list, print_parameters, dump_yaml, load_yaml, error, warning
from clearml_agent.interface.base import ObjectID
from clearml_agent.session import Session
class NameResolutionError(CommandFailedError):
    """Raised when an object name cannot be resolved to an object ID.

    Rendering the error with ``str()`` yields the message immediately
    followed by the (possibly empty) suggestions text.
    """

    def __init__(self, message, suggestions=''):
        """
        :param message: human-readable description of the failure
        :param suggestions: optional text appended to ``message`` when the
            error is rendered (for example a list of similarly named objects)
        """
        super(NameResolutionError, self).__init__(message)
        self.message = message
        self.suggestions = suggestions

    def __str__(self):
        # Suggestions are concatenated directly; callers pass them pre-formatted
        # (including any leading punctuation/newlines).
        return self.message + self.suggestions
# Decorator for command methods. Before calling the wrapped method, it replaces
# ObjectID arguments (and lists/tuples consisting only of ObjectIDs) with the
# value returned by command._resolve_name. All other arguments pass through
# unchanged.
# NOTE(review): this copy has lost its indentation and several lines (see the
# notes below); restore from upstream before relying on it.
def resolve_names(func):
# Resolve one ObjectID. Returns (resolved_id, None) on success, or
# (unresolved_name, sys.exc_info()) when _resolve_name raises NameResolutionError.
# NOTE(review): the `try:` paired with the `except` below is missing here.
def safe_resolve(command, arg):
result = command._resolve_name(arg.name, arg.service)
return result, None
except NameResolutionError:
return arg.name, sys.exc_info()
# Resolve one argument. A single ObjectID is resolved directly (errors propagate);
# a list/tuple of ObjectIDs is resolved element by element via safe_resolve;
# anything else is returned unchanged.
def _resolve_single(command, arg):
if isinstance(arg, ObjectID):
return command._resolve_name(arg.name, arg.service)
elif isinstance(arg, (list, tuple)) and all(isinstance(x, ObjectID) for x in arg):
result = [safe_resolve(command, x) for x in arg]
if len(result) == 1:
name, ex = result[0]
# NOTE(review): `if ex:` is directly followed by `return [name]`. That is
# presumably the success-path return; the statement that handles `ex`
# (likely a re-raise) appears to be missing. Confirm upstream.
if ex:
return [name]
# NOTE(review): lines appear to be lost in this loop as well. Presumably
# failures are reported here and the list of names is returned after the
# loop. Confirm upstream.
for _, ex in result:
if ex:
return [name for (name, _) in result]
return arg
# NOTE(review): `wraps` is imported at the top of the file but unused in the
# visible code; newfunc was presumably decorated with @wraps(func). Confirm.
def newfunc(self, *args, **kwargs):
args = [_resolve_single(self, arg) for arg in args]
kwargs = {key: _resolve_single(self, value) for key, value in kwargs.items()}
return func(self, *args, **kwargs)
return newfunc
# Base class for command sections that do not call the backend API. Provides
# the user-interaction helpers warning/error/log/exit.
# NOTE(review): this copy has lost its indentation, the docstring quote
# delimiters around the two text lines below, the method decorators, and the
# body of exit(). Restore from upstream before running.
class BaseCommandSection(object):
Base class for command sections which do not interact with the allegro API.
Has basic utilities for user interaction.
# The module-level helpers imported from clearml_agent.helper.base, wrapped as
# static methods so they can be called as self.warning(...) / self.error(...).
warning = staticmethod(warning)
error = staticmethod(error)
# Print "clearml-agent: " followed by (message % args). Has no self/cls
# parameter, so presumably a @staticmethod upstream.
# NOTE(review): `%` formatting is applied even when args is empty, so a literal
# '%' in message must be written as '%%' or the call raises.
def log(message, *args):
print("clearml-agent: {}".format(message % args))
# Takes `cls`, so presumably a @classmethod upstream. Its body is missing from
# this copy; by its name and `code` parameter it presumably reports `message`
# and terminates with `code`. Confirm upstream.
def exit(cls, message, code=1): # type: (Text, int) -> ()
# Base class for command sections that call the backend API through a Session.
# It provides:
#   - get/post wrappers;
#   - info and update helpers;
#   - bulk application of a request class to many ids;
#   - name-to-ID resolution.
# NOTE(review): this copy has lost its indentation and a number of lines
# (docstring quote delimiters, decorators, `try:`/`else:`/`continue` lines,
# parts of some calls). The notes below mark the visible gaps; restore from
# upstream before running.
class ServiceCommandSection(BaseCommandSection):
Base class for command sections which interact with the allegro API.
Contains API functionality which is common across services.
_worker_name = None
# Conform the keyword arguments, open the API session with the conformed
# arguments, and create a ListFormatter for this section's service.
def __init__(self, *args, **kwargs):
super(ServiceCommandSection, self).__init__()
kwargs = self._verify_command_states(kwargs)
self._session = self._get_session(*args, **kwargs)
self._list_formatter = ListFormatter(self.service)
# Subclass hook for adjusting command arguments; the base version returns them
# unchanged. Takes `cls`, so presumably a @classmethod upstream. Confirm.
def _verify_command_states(cls, kwargs):
Conform and enforce command argument
This is where you can automatically turn on/off switches based on different states.
:param kwargs:
:return: kwargs
return kwargs
# Session factory. Has no self/cls parameter, so presumably a @staticmethod
# upstream. Confirm.
def _get_session(*args, **kwargs):
return Session(*args, **kwargs)
# Read as an attribute (`self.service`) throughout this class, so presumably an
# abstract @property upstream: `abstractmethod` is imported but otherwise
# unused in the visible code. Confirm.
def service(self):
""" The name of the REST service used by this command """
# Call session.get(service=..., action=endpoint, ...). The session defaults to
# this section's session and the service defaults to self.service.
# NOTE(review): extra positional *args are unpacked after the keyword arguments,
# so they bind to Session.get's leading positional parameters. Confirm they
# cannot collide with `service`/`action`.
def get(self, endpoint, *args, service=None, session=None, **kwargs):
session = session or self._session
service = service or self.service
return session.get(service=service, action=endpoint, *args, **kwargs)
# Same as get() above, but calls session.post.
def post(self, endpoint, *args, service=None, session=None, **kwargs):
session = session or self._session
service = service or self.service
return session.post(service=service, action=endpoint, *args, **kwargs)
# Forward to the session's get_with_act_as for this section's service.
def get_with_act_as(self, endpoint, *args, **kwargs):
return self._session.get_with_act_as(service=self.service, action=endpoint, *args, **kwargs)
# Title-cased service name. Read as an attribute (`self.name.rstrip`), so
# presumably a @property upstream.
def name(self):
return self.service.title()
# `name` with its trailing 's' removed. Note that rstrip removes every trailing
# 's', not just one.
def name_single(self):
return self.name.rstrip('s')
# Service name with trailing 's' characters removed. It is also used as the id
# key of requests and responses (see __info and _update).
def service_single(self):
return self.service.rstrip('s')
# Look up each id via the `get_by_id` endpoint and collect {id: object}. The
# result is passed to output_info (quiet/yaml handling) and returned. An
# APIError for an id is reported through self.error.
# NOTE(review): the body of `if not ids:` (presumably an early return) and the
# `try:` paired with `except APIError` below are missing from this copy.
def __info(self, id=None, yaml=None, **kwargs):
ids = return_list(id)
if not ids:
yaml_dump = {}
for i in ids:
get_fields = {self.service_single: i}
info = self.get('get_by_id', **get_fields)
yaml_dump[i] = info[self.service_single]
except APIError:
self.error('Failed retrieving info for {} {}'.format(self.service_single, i))
self.output_info(yaml_dump, yaml_path=yaml, **kwargs)
return yaml_dump
# Wrapper around the name-mangled __info, presumably so subclasses can reach it
# (name mangling hides __info from them).
# NOTE(review): the dict returned by __info is discarded here.
def _info(self, *args, **kwargs):
self.__info(*args, **kwargs)
# Print the entries unless quiet or empty, and dump them to yaml_path when it is
# given. Has no self parameter, so presumably a @staticmethod upstream.
def output_info(entries, quiet=False, yaml_path=None, **_):
if not quiet and entries:
print_parameters(entries, indent=4)
if yaml_path:
print('Storing entries to [{}]'.format(yaml_path))
dump_yaml(entries, yaml_path)
# Build and return a (query, table) pair:
#   - copy `json`;
#   - split a '#'-separated table spec into a list;
#   - use that list as only_fields when projection_from_table is set;
#   - take the first '#'-separated part of `sort` as order_by.
# Has no self parameter, so presumably a @staticmethod upstream.
# NOTE(review): the body of `if extra_fields:` appears to be missing. Confirm
# upstream how extra_fields is applied.
def _make_query(json, table, sort=None, projection_from_table=False, extra_fields=None):
json = json.copy()
if isinstance(table, six.string_types):
table = table.split('#')
if extra_fields:
if projection_from_table:
json['only_fields'] = table
if sort:
# does nothing if 'order_by' is not in get_fields
json['order_by'] = sort.split('#')[0]
return json, table
# Call `endpoint` with `json` as keyword arguments and return the response
# entry under `retpoint` (default: the service name), or [] when it is absent.
def _get_all(self, endpoint, json, retpoint=None):
return self.get(endpoint, **json).get(retpoint or self.service, [])
# Update one or more objects.
#   - Entries come from a YAML file (`yaml`) and/or from keyword arguments keyed
#     by `primary_key`.
#   - A model description file and dotted "a.b.c=value" overrides may be applied
#     to the first entry.
#   - With `send_diff`, only fields that differ from the current object are sent.
# Each entry is sent to `endpoint`; a falsy result['updated'] raises ValueError.
# NOTE(review): the rest of this signature is missing from this copy. The body
# reads endpoint, primary_key, yaml, send_diff, model_desc, override, quiet and
# kwargs.
def _update(self,
if not yaml and primary_key not in kwargs:
raise ValueError('Update must supply either yaml file or %s-id' % self.service_single)
data_entries = {}
original_data_entries = {}
if yaml:
data_entries = load_yaml(yaml)
# Fetch the current object when a diff is requested or when updating by id only.
if send_diff or (not yaml and primary_key in kwargs):
i = kwargs.get(primary_key) or next(iter(data_entries))
original_info = self.__info(id=i, quiet=True)[i]
if send_diff:
original_data_entries[i] = copy.deepcopy(original_info)
# A YAML file keyed by a different id is accepted only if it holds a single
# entry. That entry is re-keyed to `i`; 'force' is required when the ids differ.
if yaml and i not in data_entries:
if len(data_entries) > 1:
raise ValueError(
'Error: yaml file [%s] contains more than one task id' % yaml)
first_key = next(iter(data_entries))
if first_key != i:
if kwargs.get('force'):
if not quiet:
print('Warning: overriding yaml task id [%s] with id=%s' % (first_key, i))
# NOTE(review): presumably preceded by a lost `else:`, so that this raises only
# when 'force' is not set, as the message implies.
raise ValueError(
'Error: yaml task id [%s] != id [%s], use --force to override' % (first_key, i))
data_entries = {i: data_entries[first_key]}
data_entries[i][primary_key] = i
elif not yaml:
data_entries[i] = kwargs
# Store the model description file's contents under execution.model_desc.prototxt.
if model_desc:
first_key = next(iter(data_entries))
with open(model_desc) as f:
proto_data = f.read()
info = data_entries[first_key]
info['execution']['model_desc']['prototxt'] = proto_data
if override:
first_key = next(iter(data_entries))
info = data_entries[first_key]
# NOTE(review): `p.split('=')` fails to unpack when the value itself contains
# '='; `p.split('=', 1)` would allow it.
for p in override:
key, val = p.split('=') if isinstance(p, six.string_types) else p
info_key = info
keys = key.split('.')
# Create intermediate dicts for missing path components. Note that existing
# falsy values (e.g. '' or 0) are also replaced by a dict.
for k in keys[:-1]:
if not info_key.get(k):
info_key[k] = dict()
info_key = info_key[k]
info_key[keys[-1]] = val
# NOTE(review): `info` is bound only in the model_desc/override branches above;
# confirm upstream that this check sits inside them. The split also assumes tags
# is a ':'-separated string; a list value would raise AttributeError.
# always make sure tags is a list of strings
# split string to tokens ':'
# examples tags='auto_generated:draft'
if info.get('tags') is not None:
# remove empty strings from list
info['tags'] = [t for t in info['tags'].split(':') if t]
# send only change set
if send_diff:
# only send the values that changed
for i, info in data_entries.items():
org_info = original_data_entries[i]
out_info = {}
recursive_diff(org_info, info, out_info)
data_entries[i] = out_info
# Send each entry. Entries that are empty or contain only the primary key are
# meant to be skipped.
for i, info in data_entries.items():
if not info or len(info) == 0 or list(info) == [primary_key]:
if not quiet:
print('Skipping: nothing to update for %s id [%s]' % (self.service_single, i))
# NOTE(review): a `continue` after the 'Skipping' message appears to be lost;
# without it the entry would still be sent.
if not quiet:
print('Updating %s id [%s]' % (self.service_single, i))
# The endpoint receives the object id under the singular service name.
info[self.service_single] = i
result = self.get(endpoint, **info)
if not result['updated']:
raise ValueError('Failed updating %s id [%s]' % (self.service_single, i))
if not quiet:
print('%s [%s] updated fields: %s' % (self.name_single, i, result.get('fields', '')))
# Delete the given ids by sending the service's DeleteRequest through
# _apply_command.
# NOTE(review): the remaining arguments and the closing parenthesis of this call
# are missing from this copy.
def remove(self, ids, **kwargs):
return self._apply_command(
request_cls=getattr(services, self.service).DeleteRequest,
# Send request_cls(object_id, **kwargs) for every id and count the successes.
# A call succeeds when it raises no APIError and, if response_validation_field
# is given, that attribute of the response equals 1. "<ok>/<total> succeeded"
# is passed to self.log when every call succeeded, otherwise to self.exit.
def _apply_command(self, request_cls, object_ids, response_validation_field=None, **kwargs):
object_ids = return_list(object_ids)
def call_one(object_id):
error_message = '[{object_id}]: failed'.format(**locals())
# NOTE(review): the `try:` paired with `except APIError` below is missing.
response = self._session.send_api(request_cls(object_id, **kwargs))
except APIError as e:
if not self._session.debug_mode:
self.error('{}: {}'.format(error_message, e))
# NOTE(review): an `else:` (debug-mode) branch appears to be lost here. The local
# `traceback` is tested but never printed, and `print_exc` is imported but unused
# in the visible code, so the lines that print them appear to be missing too.
traceback = e.format_traceback()
if traceback:
print('Own traceback:')
return False
if not response_validation_field or getattr(response, response_validation_field) == 1:
return True
return False
succeeded = [call_one(object_id) for object_id in object_ids].count(True)
message = '{}/{} succeeded'.format(succeeded, len(object_ids))
(self.log if succeeded == len(object_ids) else self.exit)(message)
# Instantiate `service_class` with this session's config.
def get_service(self, service_class):
return service_class(config=self._session.config)
# NOTE(review): the text lines after the def are the method docstring, with its
# quote delimiters lost. The text says the method prints suggestions and exits;
# the visible code instead raises NameResolutionError carrying the suggestions.
def _resolve_name(self, name, service=None):
Resolve an object name to an object ID.
- If the argument "looks like" an ID, return it.
- Else, get all object with names containing the argument
- if an object with the argument as its name exists, return the object's ID
- Else, print a list of suggestions and exit
:param str name: ID (returned unmodified) or Name to resolve
:param str service: Service to resolve from (type of object). Defaults to service represented by the class
:return: ID of object
:rtype: str
service = service or self.service
# Strings of 30+ lowercase hex digits/dashes are treated as IDs and returned as-is.
if re.match(r'^[-a-f0-9]{30,}$', name):
return name
# NOTE(review): the `try:` paired with `except AttributeError` below is missing.
request_cls = getattr(services, service).GetAllRequest
except AttributeError:
raise NameResolutionError('Name resolution unavailable for {}'.format(service))
# The name is regex-escaped before it is sent; presumably the server treats the
# name filter as a pattern. Confirm against the API.
request = request_cls.from_dict(dict(name=re.escape(name), only_fields=['name', 'id']))
# from_dict will ignore unrecognised keyword arguments - not all GetAll's have only_fields
response = getattr(self._session.send_api(request), service)
# Only case-insensitive exact name matches count as matches.
matches = [db_object for db_object in response if name.lower() == db_object.name.lower()]
# Format a "* item" bullet list of at most MAX_SUGGESTIONS elements into
# format_string. When the list is truncated, {details} becomes
# " (showing n/total)" and {suffix} becomes a trailing "\n..."; otherwise both
# are empty.
# NOTE(review): the `kwargs.update(` opening of the next line and the `else:`
# after it are missing. MAX_SUGGESTIONS is not defined in the visible class.
def truncated_bullet_list(format_string, elements, callback, **kwargs):
if len(elements) > self.MAX_SUGGESTIONS:
dict(details=' (showing {}/{})'.format(self.MAX_SUGGESTIONS, len(elements)), suffix='\n...'))
kwargs.update(dict(details='', suffix=''))
bullet_list = '\n'.join('* {}'.format(callback(item)) for item in elements[:self.MAX_SUGGESTIONS])
return format_string.format(bullet_list, **kwargs)
if len(matches) == 1:
return matches.pop().id
elif len(matches) > 1:
# NOTE(review): the rest of the call below is missing: its elements/callback
# arguments, the format arguments, the closing parenthesis, and whatever is done
# with `message` (presumably a raise). The statement after it appears to belong
# to the no-match path.
message = truncated_bullet_list(
'Found multiple {service} with name "{name}"{details}:\n{}{suffix}',
message = 'Could not find {} with name/id "{}"'.format(service.rstrip('s'), name)
# There is nothing to suggest when the name search returned no objects.
if not response:
raise NameResolutionError(message)
# Otherwise suggest the objects returned by the name search, sorted by name and
# shown as "(id) name".
# NOTE(review): the closing parenthesis of this call is missing from this copy.
suggestions = truncated_bullet_list(
'. Did you mean this?{details}\n{}{suffix}',
sorted(response, key=attrgetter('name')),
lambda db_object: '({}) {}'.format(db_object.id, db_object.name)
raise NameResolutionError(message, suggestions)
# Write into `out` the entries of `upd` that are new or changed relative to `org`.
#   - For a changed key whose new value is a dict, the nested difference is
#     computed recursively. If that nested difference holds no further dicts,
#     the whole new sub-dict is stored instead.
#   - Changed keys whose new value is None are not written.
# Return value: for dict inputs with at least one changed key, whether any changed
# value was a dict. As shown, the list branch returns False.
# NOTE(review): this copy has lost its indentation and the closing `]` of the
# diff_keys comprehension. The list branch computes diff_list but, as shown, never
# uses it, so a line appears to be missing there. The return value when no key
# changed depends on the lost indentation. Confirm upstream.
def recursive_diff(org, upd, out):
if isinstance(upd, dict) and isinstance(org, dict):
# keys that are missing from `org` or whose values differ
diff_keys = [
k for k in upd
if k not in org or upd[k] != org[k]
if diff_keys:
has_nested_dict = False
for k in diff_keys:
if isinstance(upd[k], dict):
out[k] = {}
has_nested_dict = True
k_has_nested = recursive_diff(
org.get(k, {}), upd[k], out[k])
# no deeper dicts changed: send the whole new sub-dict rather than the partial diff
if not k_has_nested:
out[k] = upd[k]
elif upd[k] is not None:
out[k] = upd[k]
return has_nested_dict
elif isinstance(upd, list) and isinstance(org, list):
# list items of `upd` that do not appear in `org`
diff_list = [k for k in upd if k not in org]
return False