mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add Task.get_models() retrieving stored models on previously executed tasks
This commit is contained in:
parent
332e9e2f63
commit
023f1721c1
@ -2,6 +2,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
from multiprocessing import RLock
|
from multiprocessing import RLock
|
||||||
@ -13,6 +14,7 @@ except ImportError:
|
|||||||
from collections import Iterable
|
from collections import Iterable
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
from collections import OrderedDict
|
||||||
from six.moves.urllib.parse import quote
|
from six.moves.urllib.parse import quote
|
||||||
|
|
||||||
from ...utilities.locks import RLock as FileRLock
|
from ...utilities.locks import RLock as FileRLock
|
||||||
@ -560,7 +562,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
ready=True,
|
ready=True,
|
||||||
page=0,
|
page=0,
|
||||||
page_size=10,
|
page_size=10,
|
||||||
order_by='-created',
|
order_by=['-created'],
|
||||||
only_fields=['id', 'created']
|
only_fields=['id', 'created']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -847,6 +849,44 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
"""
|
"""
|
||||||
return self._initial_iteration_offset
|
return self._initial_iteration_offset
|
||||||
|
|
||||||
|
def _get_models(self, model_type='output'):
|
||||||
|
model_type = model_type.lower().strip()
|
||||||
|
assert model_type == 'output' or model_type == 'input'
|
||||||
|
|
||||||
|
if model_type == 'input':
|
||||||
|
regex = '((?i)(Using model id: )(\w+)?)'
|
||||||
|
compiled = re.compile(regex)
|
||||||
|
ids = [i[-1] for i in re.findall(compiled, self.comment)] + (
|
||||||
|
[self.input_model_id] if self.input_model_id else [])
|
||||||
|
# remove duplicates and preserve order
|
||||||
|
ids = list(OrderedDict.fromkeys(ids))
|
||||||
|
from ...model import Model as TrainsModel
|
||||||
|
in_model = []
|
||||||
|
for i in ids:
|
||||||
|
m = TrainsModel(model_id=i)
|
||||||
|
try:
|
||||||
|
# make sure the model is is valid
|
||||||
|
m._get_model_data()
|
||||||
|
in_model.append(m)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return in_model
|
||||||
|
else:
|
||||||
|
res = self.send(
|
||||||
|
models.GetAllRequest(
|
||||||
|
task=[self.id],
|
||||||
|
order_by=['created'],
|
||||||
|
only_fields=['id']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not res.response.models:
|
||||||
|
return []
|
||||||
|
ids = [m.id for m in res.response.models] + ([self.output_model_id] if self.output_model_id else [])
|
||||||
|
# remove duplicates and preserve order
|
||||||
|
ids = list(OrderedDict.fromkeys(ids))
|
||||||
|
from ...model import Model as TrainsModel
|
||||||
|
return [TrainsModel(model_id=i) for i in ids]
|
||||||
|
|
||||||
def _get_default_report_storage_uri(self):
|
def _get_default_report_storage_uri(self):
|
||||||
if not self._files_server:
|
if not self._files_server:
|
||||||
self._files_server = Session.get_files_server_host()
|
self._files_server = Session.get_files_server_host()
|
||||||
|
@ -168,12 +168,14 @@ class Artifact(object):
|
|||||||
from trains.storage.helper import StorageHelper
|
from trains.storage.helper import StorageHelper
|
||||||
local_path = StorageHelper.get_local_copy(self.url)
|
local_path = StorageHelper.get_local_copy(self.url)
|
||||||
if local_path and extract_archive and self.type == 'archive':
|
if local_path and extract_archive and self.type == 'archive':
|
||||||
|
temp_folder = None
|
||||||
try:
|
try:
|
||||||
temp_folder = mkdtemp(prefix='artifact_', suffix='.archive_'+self.name)
|
temp_folder = mkdtemp(prefix='artifact_', suffix='.archive_'+self.name)
|
||||||
ZipFile(local_path).extractall(path=temp_folder)
|
ZipFile(local_path).extractall(path=temp_folder)
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
Path(temp_folder).rmdir()
|
if temp_folder:
|
||||||
|
Path(temp_folder).rmdir()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return local_path
|
return local_path
|
||||||
|
@ -803,6 +803,20 @@ class Task(_Task):
|
|||||||
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
|
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
|
||||||
metadata=metadata, delete_after_upload=delete_after_upload)
|
metadata=metadata, delete_after_upload=delete_after_upload)
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
"""
|
||||||
|
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task.
|
||||||
|
Input models are files loaded in the task, either manually or automatically logged
|
||||||
|
Output models are files stored in the task, either manually or automatically logged
|
||||||
|
Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
|
||||||
|
|
||||||
|
:return dict: dict with keys input/output, each is list of Model objects.
|
||||||
|
Example: {'input': [trains.Model()], 'output': [trains.Model()]}
|
||||||
|
"""
|
||||||
|
task_models = {'input': self._get_models(model_type='input'),
|
||||||
|
'output': self._get_models(model_type='output')}
|
||||||
|
return task_models
|
||||||
|
|
||||||
def is_current_task(self):
|
def is_current_task(self):
|
||||||
"""
|
"""
|
||||||
Check if this task is the main task (returned by Task.init())
|
Check if this task is the main task (returned by Task.init())
|
||||||
|
Loading…
Reference in New Issue
Block a user