Add Task.get_models() retrieving stored models on previously executed tasks

This commit is contained in:
allegroai 2020-03-22 18:19:07 +02:00
parent 332e9e2f63
commit 023f1721c1
3 changed files with 58 additions and 2 deletions

View File

@ -2,6 +2,7 @@
import itertools
import logging
import os
import re
from enum import Enum
from tempfile import gettempdir
from multiprocessing import RLock
@ -13,6 +14,7 @@ except ImportError:
from collections import Iterable
import six
from collections import OrderedDict
from six.moves.urllib.parse import quote
from ...utilities.locks import RLock as FileRLock
@ -560,7 +562,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
ready=True,
page=0,
page_size=10,
order_by='-created',
order_by=['-created'],
only_fields=['id', 'created']
)
)
@ -847,6 +849,44 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
"""
return self._initial_iteration_offset
def _get_models(self, model_type='output'):
model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input'
if model_type == 'input':
regex = '((?i)(Using model id: )(\w+)?)'
compiled = re.compile(regex)
ids = [i[-1] for i in re.findall(compiled, self.comment)] + (
[self.input_model_id] if self.input_model_id else [])
# remove duplicates and preserve order
ids = list(OrderedDict.fromkeys(ids))
from ...model import Model as TrainsModel
in_model = []
for i in ids:
m = TrainsModel(model_id=i)
try:
# make sure the model is is valid
m._get_model_data()
in_model.append(m)
except:
pass
return in_model
else:
res = self.send(
models.GetAllRequest(
task=[self.id],
order_by=['created'],
only_fields=['id']
)
)
if not res.response.models:
return []
ids = [m.id for m in res.response.models] + ([self.output_model_id] if self.output_model_id else [])
# remove duplicates and preserve order
ids = list(OrderedDict.fromkeys(ids))
from ...model import Model as TrainsModel
return [TrainsModel(model_id=i) for i in ids]
def _get_default_report_storage_uri(self):
if not self._files_server:
self._files_server = Session.get_files_server_host()

View File

@ -168,12 +168,14 @@ class Artifact(object):
from trains.storage.helper import StorageHelper
local_path = StorageHelper.get_local_copy(self.url)
if local_path and extract_archive and self.type == 'archive':
temp_folder = None
try:
temp_folder = mkdtemp(prefix='artifact_', suffix='.archive_'+self.name)
ZipFile(local_path).extractall(path=temp_folder)
except Exception:
try:
Path(temp_folder).rmdir()
if temp_folder:
Path(temp_folder).rmdir()
except Exception:
pass
return local_path

View File

@ -803,6 +803,20 @@ class Task(_Task):
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
metadata=metadata, delete_after_upload=delete_after_upload)
def get_models(self):
"""
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
Input models are files loaded in the task, either manually or automatically logged
Output models are files stored in the task, either manually or automatically logged
Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
:return dict: dict with keys input/output, each is list of Model objects.
Example: {'input': [trains.Model()], 'output': [trains.Model()]}
"""
task_models = {'input': self._get_models(model_type='input'),
'output': self._get_models(model_type='output')}
return task_models
def is_current_task(self):
"""
Check if this task is the main task (returned by Task.init())