From 023f1721c139a7f9e4ae65c0fbfc3b3be91aa7c5 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sun, 22 Mar 2020 18:19:07 +0200
Subject: [PATCH] Add Task.get_models() retrieving stored models on previously
 executed tasks

---
Review notes (text between the `---` line and the first `diff --git` line is
not part of the commit message):
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch. If a context line does not match the tree, apply with
   `git apply --ignore-whitespace`.
 * _get_models() was revised during review (regex flag handling, None comment,
   argument validation, narrowed except). The post-image blob id on the
   task.py `index` line therefore still refers to the pre-review revision.

 trains/backend_interface/task/task.py | 53 ++++++++++++++++++++++++++-
 trains/binding/artifacts.py           |  4 +-
 trains/task.py                        | 14 +++++++
 3 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index dded2745..9d6daec9 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import os
+import re
 from enum import Enum
 from tempfile import gettempdir
 from multiprocessing import RLock
@@ -13,6 +14,7 @@ except ImportError:
     from collections import Iterable
 
 import six
+from collections import OrderedDict
 from six.moves.urllib.parse import quote
 
 from ...utilities.locks import RLock as FileRLock
@@ -560,7 +562,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                     ready=True,
                     page=0,
                     page_size=10,
-                    order_by='-created',
+                    order_by=['-created'],
                     only_fields=['id', 'created']
                 )
             )
@@ -847,6 +849,55 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         """
         return self._initial_iteration_offset
 
+    def _get_models(self, model_type='output'):
+        """
+        Return the input or output models associated with this task.
+
+        :param str model_type: 'input' or 'output' (case and surrounding whitespace are ignored)
+        :return list: Model objects, without duplicates, order preserved
+        :raise ValueError: if model_type is neither 'input' nor 'output'
+        """
+        model_type = model_type.lower().strip()
+        if model_type not in ('input', 'output'):
+            raise ValueError("model_type must be 'input' or 'output', got '{}'".format(model_type))
+
+        if model_type == 'input':
+            # pass the flag to compile(): an inline (?i) that is not at the start of the pattern
+            # is deprecated since Python 3.6 and rejected since Python 3.11
+            compiled = re.compile(r'Using model id: (\w+)', re.IGNORECASE)
+            # `or ''` guards against a missing (None) comment, which findall() would reject
+            ids = compiled.findall(self.comment or '') + (
+                [self.input_model_id] if self.input_model_id else [])
+            # remove duplicates and preserve order
+            ids = list(OrderedDict.fromkeys(ids))
+            from ...model import Model as TrainsModel
+            in_model = []
+            for i in ids:
+                m = TrainsModel(model_id=i)
+                try:
+                    # make sure the model is valid
+                    m._get_model_data()
+                    in_model.append(m)
+                except Exception:
+                    # skip ids whose model data cannot be retrieved
+                    pass
+            return in_model
+        else:
+            res = self.send(
+                models.GetAllRequest(
+                    task=[self.id],
+                    order_by=['created'],
+                    only_fields=['id']
+                )
+            )
+            if not res.response.models:
+                return []
+            ids = [m.id for m in res.response.models] + ([self.output_model_id] if self.output_model_id else [])
+            # remove duplicates and preserve order
+            ids = list(OrderedDict.fromkeys(ids))
+            from ...model import Model as TrainsModel
+            return [TrainsModel(model_id=i) for i in ids]
+
     def _get_default_report_storage_uri(self):
         if not self._files_server:
             self._files_server = Session.get_files_server_host()
diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py
index 325f6e9e..1bff88c3 100644
--- a/trains/binding/artifacts.py
+++ b/trains/binding/artifacts.py
@@ -168,12 +168,14 @@ class Artifact(object):
         from trains.storage.helper import StorageHelper
         local_path = StorageHelper.get_local_copy(self.url)
         if local_path and extract_archive and self.type == 'archive':
+            temp_folder = None
             try:
                 temp_folder = mkdtemp(prefix='artifact_', suffix='.archive_'+self.name)
                 ZipFile(local_path).extractall(path=temp_folder)
             except Exception:
                 try:
-                    Path(temp_folder).rmdir()
+                    if temp_folder:
+                        Path(temp_folder).rmdir()
                 except Exception:
                     pass
                 return local_path
diff --git a/trains/task.py b/trains/task.py
index 330c3232..fe9b194c 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -803,6 +803,20 @@ class Task(_Task):
         return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
                                                        metadata=metadata, delete_after_upload=delete_after_upload)
 
+    def get_models(self):
+        """
+        Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
+        Input models are files loaded in the task, either manually or automatically logged
+        Output models are files stored in the task, either manually or automatically logged
+        Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
+
+        :return dict: dict with keys input/output, each is list of Model objects.
+            Example: {'input': [trains.Model()], 'output': [trains.Model()]}
+        """
+        task_models = {'input': self._get_models(model_type='input'),
+                       'output': self._get_models(model_type='output')}
+        return task_models
+
     def is_current_task(self):
         """
         Check if this task is the main task (returned by Task.init())