import os
from collections import namedtuple
from functools import partial
from tempfile import mkstemp

import six
from pathlib2 import Path

from ..backend_api import Session
from ..backend_api.services import models
from .base import IdObjectBase
from .util import make_message
from ..storage import StorageHelper
from ..utilities.async_manager import AsyncManagerMixin

# Pair returned by Model.get_model_package(): local weights file path and local design file path
ModelPackage = namedtuple('ModelPackage', 'weights design')


class ModelDoesNotExistError(Exception):
    pass


class _StorageUriMixin(object):
    """ Mixin exposing an ``upload_storage_uri`` property normalized without a trailing slash """

    @property
    def upload_storage_uri(self):
        """ A URI into which models are uploaded """
        return self._upload_storage_uri

    @upload_storage_uri.setter
    def upload_storage_uri(self, value):
        # strip trailing '/' so paths can be joined with '/' later; falsy values are stored as None
        self._upload_storage_uri = value.rstrip('/') if value else None


def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
    """
    Build a local stand-in ``models.Model`` object that carries an upload storage URI.

    The returned object's ``update(**kwargs)`` only sets the given attributes locally.

    :param upload_storage_uri: URI to store on the dummy model (normalized by _StorageUriMixin)
    :param args: Forwarded to ``models.Model.__init__``
    :param kwargs: Forwarded to ``models.Model.__init__``
    :return: A DummyModel instance
    """
    class DummyModel(models.Model, _StorageUriMixin):
        def __init__(self, upload_storage_uri=None, *args, **kwargs):
            super(DummyModel, self).__init__(*args, **kwargs)
            self.upload_storage_uri = upload_storage_uri

        def update(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs)


class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
    """ Manager for backend model objects """

    # Model id for which no backend object is fetched (__init__ uses an empty models.Model, _reload skips the call)
    _EMPTY_MODEL_ID = 'empty'

    # Class-wide map: local model file path (str) -> (model id, uri) it was uploaded to / downloaded from
    _local_model_to_id_uri = {}

    @property
    def model_id(self):
        """ Alias for the object id """
        return self.id

    @property
    def storage(self):
        """ StorageHelper for this model's upload storage URI """
        return StorageHelper.get(self.upload_storage_uri)

    def __init__(self, upload_storage_uri, cache_dir, model_id=None,
                 upload_storage_suffix='models', session=None, log=None):
        """
        :param upload_storage_uri: Base URI model files are uploaded into
        :param cache_dir: Local directory used by save_model_design_file()
        :param model_id: Backend model id; ``'empty'`` creates an object with empty local data
        :param upload_storage_suffix: Sub-path appended to upload_storage_uri for uploads ('.' if falsy)
        :param session: Backend session
        :param log: Logger
        """
        super(Model, self).__init__(id=model_id, session=session, log=log)
        self._upload_storage_suffix = upload_storage_suffix
        if model_id == self._EMPTY_MODEL_ID:
            # Set an empty data object
            self._data = models.Model()
        else:
            self._data = None
        self._cache_dir = cache_dir
        self.upload_storage_uri = upload_storage_uri

    def publish(self):
        """ Mark the model as ready (without publishing its task) and reload it """
        self.send(models.SetReadyRequest(model=self.id, publish_task=False))
        self.reload()

    def _reload(self):
        """ Reload the model object """
        if self.id == self._EMPTY_MODEL_ID:
            return
        res = self.send(models.GetByIdRequest(model=self.id))
        return res.response.model

    def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
        """
        Upload a local model file into ``<upload_storage_uri>/<suffix or '.'>/<target_filename>``.

        :param model_file: Local file path to upload
        :param async_enable: Upload asynchronously and register the result with the async manager
        :param target_filename: Destination file name (defaults to the local file name)
        :param cb: Optional callback, invoked through _upload_callback with the upload result
        :return: The destination path (returned immediately, also in async mode)
        :raises ValueError: If no upload storage URI is configured
        """
        if not self.upload_storage_uri:
            raise ValueError('Model has no storage URI defined (nowhere to upload to)')
        helper = self.storage
        target_filename = target_filename or Path(model_file).name
        dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
        result = helper.upload(
            src_path=model_file,
            dest_path=dest_path,
            async_enable=async_enable,
            cb=partial(self._upload_callback, cb=cb),
        )
        if async_enable:
            def msg(num_results):
                self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))

            # presumably throttles to at most 2 pending uploads (wait_on_max_results) — semantics live in
            # AsyncManagerMixin
            self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
        return dest_path

    def _upload_callback(self, res, cb=None):
        """
        Log the upload state and forward it to the user callback.

        :param res: None (upload starting), False (upload failed) or the uploaded destination
        :param cb: Optional callback, always called with ``res``
        """
        if res is None:
            self.log.debug('Starting model upload')
        elif res is False:
            self.log.info('Failed model upload')
        else:
            self.log.info('Completed model upload to {}'.format(res))
        if cb:
            cb(res)

    @staticmethod
    def _wrap_design(design):
        """
        Wrap design text with a dictionary.

        In the backend, the design is a dictionary with a 'design' key in it.
        For the client, it is a text. This function wraps a design string with
        the proper dictionary.

        :param design: If it is a dictionary, it must have a 'design' key in it.
            In that case, return design as-is.
            If it is a string, return the dictionary {'design': design}.
            If it is None (or any False value), return the dictionary {'design': ''}

        :return: A proper design dictionary according to design parameter.
        :raises ValueError: If design is a dictionary without a 'design' key.
        """
        if isinstance(design, dict):
            if 'design' not in design:
                raise ValueError('design dictionary must have \'design\' key in it')

            return design

        return {'design': design if design else ''}

    @staticmethod
    def _unwrap_design(design):
        """
        Unwrap design text from a dictionary.

        In the backend, the design is a dictionary with a 'design' key in it.
        For the client, it is a text. This function unwraps a design string from
        the dictionary.

        :param design: If it is a dictionary with a 'design' key in it, return
            design['design'].
            If it is a dictionary without 'design' key, return the first value
            in its values list.
            If it is an empty dictionary, None, or any other False value,
            return an empty string.
            If it is a string, return design as-is.

        :return: The design string according to design parameter.
        :raises ValueError: If design is a truthy value that is neither a string nor a dictionary.
        """
        if not design:
            return ''

        if isinstance(design, six.string_types):
            return design

        if isinstance(design, dict):
            if 'design' in design:
                return design['design']

            return list(design.values())[0]

        raise ValueError('design must be a string or a dictionary with at least one value')

    def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
               task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
               upload_storage_uri=None, target_filename=None, iteration=None):
        """
        Update model weights file and various model properties.

        Creates an empty backend model first if this object has no id yet. Any argument left
        falsy (except ``iteration``) keeps the current backend value.

        :param model_file: Local weights file to upload (only used when ``uri`` is not given)
        :param uri: Weights URI to set directly; if both model_file and uri are given, the file is
            recorded as a local copy of uri
        :param upload_storage_uri: Overrides this object's upload storage URI
        :param target_filename: Destination file name when uploading model_file
        """
        if self.id is None:
            if upload_storage_uri:
                self.upload_storage_uri = upload_storage_uri
            self._create_empty_model(self.upload_storage_uri)
        elif upload_storage_uri:
            self.upload_storage_uri = upload_storage_uri

        if model_file and uri:
            Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)

        # upload model file if needed and get uri
        uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
        # update fields
        design = self._wrap_design(design) if design else self.data.design
        name = name or self.data.name
        comment = comment or self.data.comment
        labels = labels or self.data.labels
        task = task_id or self.data.task
        project = project_id or self.data.project
        parent = parent_id or self.data.parent
        if tags:
            # newer backends expose 'system_tags' on the model data object; older ones only 'tags'
            extra = {'system_tags': tags} if hasattr(self.data, 'system_tags') else {'tags': tags}
        else:
            extra = {}

        self.send(models.EditRequest(
            model=self.id,
            uri=uri,
            name=name,
            comment=comment,
            labels=labels,
            design=design,
            task=task,
            project=project,
            parent=parent,
            framework=framework or self.data.framework,
            iteration=iteration,
            **extra
        ))
        self.reload()

    def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
             uri=None, framework=None, iteration=None):
        """
        Send an edit request with exactly the given fields (no back-fill from current data) and reload.

        ``design`` is wrapped into the backend dictionary form when given.
        """
        if tags:
            extra = {'system_tags': tags} if hasattr(self.data, 'system_tags') else {'tags': tags}
        else:
            extra = {}
        self.send(models.EditRequest(
            model=self.id,
            uri=uri,
            name=name,
            comment=comment,
            labels=labels,
            design=self._wrap_design(design) if design else None,
            framework=framework,
            iteration=iteration,
            **extra
        ))
        self.reload()

    def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None,
                          tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False,
                          target_filename=None, cb=None, iteration=None):
        """
        Upload a weights file and update this model (via update()) with the resulting URI.

        :param model_file: Local weights file to upload
        :param async_enable: Upload asynchronously; the model is updated from the upload callback.
            A failed async upload sets the URI to ``<upload_storage_uri>/failed_uploading``.
        :param cb: Async mode only: called with ``model_file`` after the model was updated
        :param iteration: Iteration number forwarded to update() (in both sync and async modes)
        :return: The upload destination path
        """
        if async_enable:
            def callback(uploaded_uri):
                if uploaded_uri is None:
                    return

                # If not successful, mark model as failed_uploading
                if uploaded_uri is False:
                    uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)

                Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri)

                self.update(
                    uri=uploaded_uri,
                    task_id=task_id,
                    name=name,
                    comment=comment,
                    tags=tags,
                    design=design,
                    labels=labels,
                    project_id=project_id,
                    parent_id=parent_id,
                    framework=framework,
                    iteration=iteration,
                )

                if cb:
                    cb(model_file)

            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
                                     cb=callback)
            return uri

        uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
        Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
        self.update(
            uri=uri,
            task_id=task_id,
            name=name,
            comment=comment,
            tags=tags,
            design=design,
            labels=labels,
            project_id=project_id,
            parent_id=parent_id,
            framework=framework,
            # previously dropped in the synchronous path while the async path forwarded it
            iteration=iteration,
        )
        return uri

    def _complete_update_for_task(self, uri, task_id=None, name=None, comment=None, tags=None, override_model_id=None,
                                  cb=None, iteration=None):
        """
        Send one UpdateForTaskRequest for this model and sync the local object.

        When the model data is already loaded, missing name/comment/tags/uri are taken from it, and the
        URI is sent as None if ``override_model_id`` is given. Also used as the upload-completion
        callback in update_for_task_and_upload().

        :param uri: Weights URI
        :param task_id: Task to update the model for
        :param override_model_id: Forwarded to the request
        :param cb: Optional callback called with the final URI; its exceptions are logged, not raised
        :param iteration: Optional iteration number; only included in the request when not None
        """
        if self._data:
            name = name or self.data.name
            comment = comment or self.data.comment
            tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags)
            uri = (uri or self.data.uri) if not override_model_id else None

        if tags:
            extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
        else:
            extra = {}
        # only send an iteration when one was given, so callers that never pass it send the same request as before
        if iteration is not None:
            extra['iteration'] = iteration

        res = self.send(
            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
                                        override_model_id=override_model_id, **extra))
        if self.id is None:
            # update the model id. in case it was just created, this will trigger a reload of the model object
            self.id = res.response.id
        else:
            self.reload()
        try:
            if cb:
                cb(uri)
        except Exception as ex:
            # best-effort: a failing user callback must not break the model update
            self.log.warning('Failed calling callback on complete_update_for_task: %s' % str(ex))

    def update_for_task_and_upload(
            self, model_file, task_id, name=None, comment=None, tags=None, override_model_id=None, target_filename=None,
            async_enable=False, cb=None, iteration=None):
        """
        Upload a weights file and update the model for the given task ID with the resulting URI.

        Exactly one UpdateForTaskRequest is sent (through _complete_update_for_task), carrying
        ``iteration`` when it is given, in both sync and async modes.

        :param cb: Async mode only: called with the URI once the task model was updated
        :return: The upload destination path
        """
        if async_enable:
            callback = partial(
                self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
                override_model_id=override_model_id, cb=cb, iteration=iteration)
            uri = self._upload_model(model_file, target_filename=target_filename,
                                     async_enable=async_enable, cb=callback)
            return uri

        uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
        # a single request; previously a second, duplicate UpdateForTaskRequest was sent here just to pass iteration
        self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id, iteration=iteration)
        return uri

    def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):
        """ Update the model for the given task ID (see _complete_update_for_task) """
        self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)

    @property
    def model_design(self):
        """ Get the model design. For now, this is stored as a single key in the design dict. """
        try:
            return self._unwrap_design(self.data.design)
        except ValueError:
            # no design is yet specified
            return None

    @property
    def labels(self):
        """ Model labels, or None if accessing them raises ValueError """
        try:
            return self.data.labels
        except ValueError:
            # no labels is yet specified
            return None

    @property
    def name(self):
        """ Model name, or None if accessing it raises ValueError """
        try:
            return self.data.name
        except ValueError:
            # no name is yet specified
            return None

    @property
    def comment(self):
        """ Model comment, or None if accessing it raises ValueError """
        try:
            return self.data.comment
        except ValueError:
            # no comment is yet specified
            return None

    @property
    def tags(self):
        """ Model tags: ``system_tags`` when the data object has that attribute, otherwise ``tags`` """
        return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags

    @property
    def task(self):
        """ Model task, or None if accessing it raises ValueError """
        try:
            return self.data.task
        except ValueError:
            # no task is yet specified
            return None

    @property
    def uri(self):
        """ Model weights URI, or None if accessing it raises ValueError """
        try:
            return self.data.uri
        except ValueError:
            # no uri is yet specified
            return None

    @property
    def locked(self):
        """ True if the model is marked ready; a model without an id is never locked """
        if self.id is None:
            return False
        return bool(self.data.ready)

    def download_model_weights(self):
        """
        Download the model weights into a local file in our cache.

        Reuses a previously recorded local file for the same model id and URI if it still exists.

        :return: Local file path, or None if the model has no URI
        """
        uri = self.data.uri
        if not uri or not uri.strip():
            return None

        # check if we already downloaded the file
        downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
        for dl_file in downloaded_models:
            if Path(dl_file).exists():
                return dl_file
            # remove non existing model file
            Model._local_model_to_id_uri.pop(dl_file, None)

        local_download = StorageHelper.get(uri).get_local_copy(uri)

        # save local model, so we can later query what was the original one
        Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
        return local_download

    @property
    def cache_dir(self):
        """ Local directory used for design files """
        return self._cache_dir

    def save_model_design_file(self):
        """
        Write the model design text into ``<cache_dir>/<model name>.txt``.

        :return: The written file path (str)
        """
        design = self.model_design
        filename = self.data.name + '.txt'
        p = Path(self.cache_dir) / filename
        # we always write the original model design to file (overwriting any existing one), to prevent any mishaps
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(six.text_type(design))
        return str(p)

    def get_model_package(self):
        """ Get a named tuple containing the model's weights and design """
        return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())

    def get_model_design(self):
        """ Get model description (text) """
        return self.model_design

    @classmethod
    def get_all(cls, session, log=None, **kwargs):
        """ Send a GetAllRequest built from ``kwargs`` and return the raw response """
        req = models.GetAllRequest(**kwargs)
        res = cls._send(session=session, req=req, log=log)
        return res

    def clone(self, name, comment=None, child=True, tags=None, task=None, ready=True):
        """
        Clone this model into a new model.

        :param name: Name for the new model
        :param comment: Optional comment for the new model
        :param child: Should the new model be a child of this model? (default True)
        :param tags: Optional tags for the new model (defaults to this model's tags)
        :param task: Optional task for the new model
        :param ready: Create the new model as ready (default True)
        :return: The new model's ID
        """
        data = self.data
        assert isinstance(data, models.Model)
        parent = self.id if child else None
        extra = {'system_tags': tags or data.system_tags} \
            if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
        req = models.CreateRequest(
            uri=data.uri,
            name=name,
            labels=data.labels,
            comment=comment or data.comment,
            framework=data.framework,
            design=data.design,
            ready=ready,
            project=data.project,
            parent=parent,
            task=task,
            **extra
        )
        res = self.send(req)
        return res.response.id

    def _create_empty_model(self, upload_storage_uri=None):
        """
        Create a placeholder backend model and adopt its id.

        :param upload_storage_uri: Base URI for the placeholder URI (defaults to this object's, then 'file://')
        :return: True on success, False if the create request returned a falsy result
        """
        upload_storage_uri = upload_storage_uri or self.upload_storage_uri
        name = make_message('Anonymous model %(time)s')
        uri = '{}/uploading_file'.format(upload_storage_uri or 'file://')
        req = models.CreateRequest(uri=uri, name=name, labels={})
        res = self.send(req)
        if not res:
            return False
        self.id = res.response.id
        return True