mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Add Task.get_models() retrieving stored models on previously executed tasks
This commit is contained in:
parent
332e9e2f63
commit
023f1721c1
@ -2,6 +2,7 @@
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from tempfile import gettempdir
|
||||
from multiprocessing import RLock
|
||||
@ -13,6 +14,7 @@ except ImportError:
|
||||
from collections import Iterable
|
||||
|
||||
import six
|
||||
from collections import OrderedDict
|
||||
from six.moves.urllib.parse import quote
|
||||
|
||||
from ...utilities.locks import RLock as FileRLock
|
||||
@ -560,7 +562,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
ready=True,
|
||||
page=0,
|
||||
page_size=10,
|
||||
order_by='-created',
|
||||
order_by=['-created'],
|
||||
only_fields=['id', 'created']
|
||||
)
|
||||
)
|
||||
@ -847,6 +849,44 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
"""
|
||||
return self._initial_iteration_offset
|
||||
|
||||
def _get_models(self, model_type='output'):
|
||||
model_type = model_type.lower().strip()
|
||||
assert model_type == 'output' or model_type == 'input'
|
||||
|
||||
if model_type == 'input':
|
||||
regex = '((?i)(Using model id: )(\w+)?)'
|
||||
compiled = re.compile(regex)
|
||||
ids = [i[-1] for i in re.findall(compiled, self.comment)] + (
|
||||
[self.input_model_id] if self.input_model_id else [])
|
||||
# remove duplicates and preserve order
|
||||
ids = list(OrderedDict.fromkeys(ids))
|
||||
from ...model import Model as TrainsModel
|
||||
in_model = []
|
||||
for i in ids:
|
||||
m = TrainsModel(model_id=i)
|
||||
try:
|
||||
# make sure the model is is valid
|
||||
m._get_model_data()
|
||||
in_model.append(m)
|
||||
except:
|
||||
pass
|
||||
return in_model
|
||||
else:
|
||||
res = self.send(
|
||||
models.GetAllRequest(
|
||||
task=[self.id],
|
||||
order_by=['created'],
|
||||
only_fields=['id']
|
||||
)
|
||||
)
|
||||
if not res.response.models:
|
||||
return []
|
||||
ids = [m.id for m in res.response.models] + ([self.output_model_id] if self.output_model_id else [])
|
||||
# remove duplicates and preserve order
|
||||
ids = list(OrderedDict.fromkeys(ids))
|
||||
from ...model import Model as TrainsModel
|
||||
return [TrainsModel(model_id=i) for i in ids]
|
||||
|
||||
def _get_default_report_storage_uri(self):
|
||||
if not self._files_server:
|
||||
self._files_server = Session.get_files_server_host()
|
||||
|
@ -168,12 +168,14 @@ class Artifact(object):
|
||||
from trains.storage.helper import StorageHelper
|
||||
local_path = StorageHelper.get_local_copy(self.url)
|
||||
if local_path and extract_archive and self.type == 'archive':
|
||||
temp_folder = None
|
||||
try:
|
||||
temp_folder = mkdtemp(prefix='artifact_', suffix='.archive_'+self.name)
|
||||
ZipFile(local_path).extractall(path=temp_folder)
|
||||
except Exception:
|
||||
try:
|
||||
Path(temp_folder).rmdir()
|
||||
if temp_folder:
|
||||
Path(temp_folder).rmdir()
|
||||
except Exception:
|
||||
pass
|
||||
return local_path
|
||||
|
@ -803,6 +803,20 @@ class Task(_Task):
|
||||
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
|
||||
metadata=metadata, delete_after_upload=delete_after_upload)
|
||||
|
||||
def get_models(self):
|
||||
"""
|
||||
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
|
||||
Input models are files loaded in the task, either manually or automatically logged
|
||||
Output models are files stored in the task, either manually or automatically logged
|
||||
Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
|
||||
|
||||
:return dict: dict with keys input/output, each is list of Model objects.
|
||||
Example: {'input': [trains.Model()], 'output': [trains.Model()]}
|
||||
"""
|
||||
task_models = {'input': self._get_models(model_type='input'),
|
||||
'output': self._get_models(model_type='output')}
|
||||
return task_models
|
||||
|
||||
def is_current_task(self):
|
||||
"""
|
||||
Check if this task is the main task (returned by Task.init())
|
||||
|
Loading…
Reference in New Issue
Block a user