Support ClearML server API v2.13

This commit is contained in:
allegroai 2021-04-10 22:32:49 +03:00
parent 139c4ffe86
commit 86148a928b
17 changed files with 30090 additions and 369 deletions

View File

@ -0,0 +1,618 @@
"""
auth service
This service provides authentication management and authorization
validation for the entire system.
"""
from datetime import datetime
import six
from clearml.backend_api.session import (
Request,
Response,
NonStrictDataModel,
schema_property,
)
from dateutil.parser import parse as parse_datetime
class Credentials(NonStrictDataModel):
    """
    An access-key / secret-key credentials pair.

    :param access_key: Credentials access key
    :type access_key: str
    :param secret_key: Credentials secret key
    :type secret_key: str
    """

    _schema = {
        "properties": {
            "access_key": {
                "description": "Credentials access key",
                "type": ["string", "null"],
            },
            "secret_key": {
                "description": "Credentials secret key",
                "type": ["string", "null"],
            },
        },
        "type": "object",
    }

    def __init__(self, access_key=None, secret_key=None, **kwargs):
        super(Credentials, self).__init__(**kwargs)
        self.access_key = access_key
        self.secret_key = secret_key

    @schema_property("access_key")
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is not None:
            self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property("secret_key")
    def secret_key(self):
        return self._property_secret_key

    @secret_key.setter
    def secret_key(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is not None:
            self.assert_isinstance(value, "secret_key", six.string_types)
        self._property_secret_key = value
class CredentialKey(NonStrictDataModel):
    """
    A credentials access key together with its last-use metadata.

    :param access_key:
    :type access_key: str
    :param last_used:
    :type last_used: datetime.datetime
    :param last_used_from:
    :type last_used_from: str
    """

    _schema = {
        "properties": {
            "access_key": {"description": "", "type": ["string", "null"]},
            "last_used": {
                "description": "",
                "format": "date-time",
                "type": ["string", "null"],
            },
            "last_used_from": {"description": "", "type": ["string", "null"]},
        },
        "type": "object",
    }

    def __init__(self, access_key=None, last_used=None, last_used_from=None, **kwargs):
        super(CredentialKey, self).__init__(**kwargs)
        self.access_key = access_key
        self.last_used = last_used
        self.last_used_from = last_used_from

    @schema_property("access_key")
    def access_key(self):
        return self._property_access_key

    @access_key.setter
    def access_key(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is not None:
            self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value

    @schema_property("last_used")
    def last_used(self):
        return self._property_last_used

    @last_used.setter
    def last_used(self, value):
        if value is None:
            self._property_last_used = None
            return
        # Accept either a datetime or a date-time string; strings are
        # parsed into datetime objects before being stored.
        self.assert_isinstance(value, "last_used", six.string_types + (datetime,))
        self._property_last_used = (
            value if isinstance(value, datetime) else parse_datetime(value)
        )

    @schema_property("last_used_from")
    def last_used_from(self):
        return self._property_last_used_from

    @last_used_from.setter
    def last_used_from(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is not None:
            self.assert_isinstance(value, "last_used_from", six.string_types)
        self._property_last_used_from = value
class CreateCredentialsRequest(Request):
    """
    Creates a new set of credentials for the authenticated user.
    New key/secret is returned.
    Note: Secret will never be returned in any other API call.
    If a secret is lost or compromised, the key should be revoked
    and a new set of credentials can be created.
    """
    # Endpoint identity: auth.create_credentials, server API version 2.13.
    _service = "auth"
    _action = "create_credentials"
    _version = "2.13"
    # Request payload schema: this endpoint takes no parameters, and
    # "additionalProperties": False rejects any stray fields.
    _schema = {
        "additionalProperties": False,
        "definitions": {},
        "properties": {},
        "type": "object",
    }
class CreateCredentialsResponse(Response):
    """
    Response of auth.create_credentials endpoint.

    :param credentials: Created credentials
    :type credentials: Credentials
    """

    _service = "auth"
    _action = "create_credentials"
    _version = "2.13"
    _schema = {
        "definitions": {
            "credentials": {
                "properties": {
                    "access_key": {
                        "description": "Credentials access key",
                        "type": ["string", "null"],
                    },
                    "secret_key": {
                        "description": "Credentials secret key",
                        "type": ["string", "null"],
                    },
                },
                "type": "object",
            }
        },
        "properties": {
            "credentials": {
                "description": "Created credentials",
                "oneOf": [{"$ref": "#/definitions/credentials"}, {"type": "null"}],
            }
        },
        "type": "object",
    }

    def __init__(self, credentials=None, **kwargs):
        super(CreateCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials

    @schema_property("credentials")
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        # Raw dicts (as returned by the server) are promoted to a
        # Credentials data model; anything else must already be one.
        if isinstance(value, dict):
            self._property_credentials = Credentials.from_dict(value)
            return
        if value is not None:
            self.assert_isinstance(value, "credentials", Credentials)
        self._property_credentials = value
class EditUserRequest(Request):
    """
    Edit a user's auth data properties.

    :param user: User ID
    :type user: str
    :param role: The new user's role within the company
    :type role: str
    """

    _service = "auth"
    _action = "edit_user"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "role": {
                "description": "The new user's role within the company",
                "enum": ["admin", "superuser", "user", "annotator"],
                "type": ["string", "null"],
            },
            "user": {"description": "User ID", "type": ["string", "null"]},
        },
        "type": "object",
    }

    def __init__(self, user=None, role=None, **kwargs):
        super(EditUserRequest, self).__init__(**kwargs)
        self.user = user
        self.role = role

    @schema_property("user")
    def user(self):
        return self._property_user

    @user.setter
    def user(self, value):
        # None clears the field; otherwise only string IDs are accepted.
        if value is not None:
            self.assert_isinstance(value, "user", six.string_types)
        self._property_user = value

    @schema_property("role")
    def role(self):
        return self._property_role

    @role.setter
    def role(self, value):
        # None clears the field. Only the string type is validated here;
        # the role enum itself is enforced server-side per the schema.
        if value is not None:
            self.assert_isinstance(value, "role", six.string_types)
        self._property_role = value
class EditUserResponse(Response):
    """
    Response of auth.edit_user endpoint.

    :param updated: Number of users updated (0 or 1)
    :type updated: float
    :param fields: Updated fields names and values
    :type fields: dict
    """

    _service = "auth"
    _action = "edit_user"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "fields": {
                "additionalProperties": True,
                "description": "Updated fields names and values",
                "type": ["object", "null"],
            },
            "updated": {
                "description": "Number of users updated (0 or 1)",
                "enum": [0, 1],
                "type": ["number", "null"],
            },
        },
        "type": "object",
    }

    def __init__(self, updated=None, fields=None, **kwargs):
        super(EditUserResponse, self).__init__(**kwargs)
        self.updated = updated
        self.fields = fields

    @schema_property("updated")
    def updated(self):
        return self._property_updated

    @updated.setter
    def updated(self, value):
        # Schema type is "number", so both ints and floats are accepted.
        if value is not None:
            self.assert_isinstance(value, "updated", six.integer_types + (float,))
        self._property_updated = value

    @schema_property("fields")
    def fields(self):
        return self._property_fields

    @fields.setter
    def fields(self, value):
        # Free-form mapping of the fields that were changed.
        if value is not None:
            self.assert_isinstance(value, "fields", (dict,))
        self._property_fields = value
class GetCredentialsRequest(Request):
    """
    Returns all existing credential keys for the authenticated user.
    Note: Only credential keys are returned.
    """
    # Endpoint identity: auth.get_credentials, server API version 2.13.
    _service = "auth"
    _action = "get_credentials"
    _version = "2.13"
    # Request payload schema: this endpoint takes no parameters, and
    # "additionalProperties": False rejects any stray fields.
    _schema = {
        "additionalProperties": False,
        "definitions": {},
        "properties": {},
        "type": "object",
    }
class GetCredentialsResponse(Response):
    """
    Response of auth.get_credentials endpoint.

    :param credentials: List of credentials for the user own company, each with an
        empty secret field.
    :type credentials: Sequence[CredentialKey]
    :param additional_credentials: The user credentials for the user tenant
        companies, each with an empty secret field.
    :type additional_credentials: dict
    """

    _service = "auth"
    _action = "get_credentials"
    _version = "2.13"
    _schema = {
        "definitions": {
            "credential_key": {
                "properties": {
                    "access_key": {"description": "", "type": ["string", "null"]},
                    "last_used": {
                        "description": "",
                        "format": "date-time",
                        "type": ["string", "null"],
                    },
                    "last_used_from": {"description": "", "type": ["string", "null"]},
                },
                "type": "object",
            }
        },
        "properties": {
            "additional_credentials": {
                "additionalProperties": True,
                "description": "The user credentials for the user tenant companies, each with an empty secret field.",
                "type": ["object", "null"],
            },
            "credentials": {
                "description": "List of credentials for the user own company, each with an empty secret field.",
                "items": {"$ref": "#/definitions/credential_key"},
                "type": ["array", "null"],
            },
        },
        "type": "object",
    }

    def __init__(self, credentials=None, additional_credentials=None, **kwargs):
        super(GetCredentialsResponse, self).__init__(**kwargs)
        self.credentials = credentials
        self.additional_credentials = additional_credentials

    @schema_property("credentials")
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return
        self.assert_isinstance(value, "credentials", (list, tuple))
        # A mixed/raw list (containing dicts from the wire) is converted
        # element-wise; an already-typed list is validated as-is.
        if not any(isinstance(item, dict) for item in value):
            self.assert_isinstance(value, "credentials", CredentialKey, is_array=True)
        else:
            value = [
                CredentialKey.from_dict(item) if isinstance(item, dict) else item
                for item in value
            ]
        self._property_credentials = value

    @schema_property("additional_credentials")
    def additional_credentials(self):
        return self._property_additional_credentials

    @additional_credentials.setter
    def additional_credentials(self, value):
        # Free-form mapping keyed per tenant company.
        if value is not None:
            self.assert_isinstance(value, "additional_credentials", (dict,))
        self._property_additional_credentials = value
class LoginRequest(Request):
    """
    Get a token based on supplied credentials (key/secret).
    Intended for use by users with key/secret credentials that wish to obtain a token
    for use with other services. Token will be limited by the same permissions that
    exist for the credentials used in this call.

    :param expiration_sec: Requested token expiration time in seconds. Not
        guaranteed, might be overridden by the service
    :type expiration_sec: int
    """

    _service = "auth"
    _action = "login"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "expiration_sec": {
                "description": "Requested token expiration time in seconds. \n Not guaranteed, might be overridden by the service",
                "type": ["integer", "null"],
            }
        },
        "type": "object",
    }

    def __init__(self, expiration_sec=None, **kwargs):
        super(LoginRequest, self).__init__(**kwargs)
        self.expiration_sec = expiration_sec

    @schema_property("expiration_sec")
    def expiration_sec(self):
        return self._property_expiration_sec

    @expiration_sec.setter
    def expiration_sec(self, value):
        if value is not None:
            # Coerce integral floats (e.g. 30.0) to int before validating.
            if isinstance(value, float) and value.is_integer():
                value = int(value)
            self.assert_isinstance(value, "expiration_sec", six.integer_types)
        self._property_expiration_sec = value
class LoginResponse(Response):
    """
    Response of auth.login endpoint.

    :param token: Token string
    :type token: str
    """

    _service = "auth"
    _action = "login"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "token": {"description": "Token string", "type": ["string", "null"]}
        },
        "type": "object",
    }

    def __init__(self, token=None, **kwargs):
        super(LoginResponse, self).__init__(**kwargs)
        self.token = token

    @schema_property("token")
    def token(self):
        return self._property_token

    @token.setter
    def token(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is not None:
            self.assert_isinstance(value, "token", six.string_types)
        self._property_token = value
class RevokeCredentialsRequest(Request):
    """
    Revokes (and deletes) a set (key, secret) of credentials for
    the authenticated user.
    :param access_key: Credentials key
    :type access_key: str
    """
    # Endpoint identity: auth.revoke_credentials, server API version 2.13.
    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.13"
    # NOTE(review): the schema's "required" lists "key_id" but the only
    # declared property is "access_key" — looks like a server-side schema
    # inconsistency carried over by the generator; confirm against the
    # auth service definition before relying on client-side validation.
    _schema = {
        "definitions": {},
        "properties": {
            "access_key": {"description": "Credentials key", "type": ["string", "null"]}
        },
        "required": ["key_id"],
        "type": "object",
    }
    def __init__(self, access_key=None, **kwargs):
        super(RevokeCredentialsRequest, self).__init__(**kwargs)
        self.access_key = access_key
    @schema_property("access_key")
    def access_key(self):
        # Access key of the credentials pair to revoke.
        return self._property_access_key
    @access_key.setter
    def access_key(self, value):
        # None clears the field; otherwise only strings are accepted.
        if value is None:
            self._property_access_key = None
            return
        self.assert_isinstance(value, "access_key", six.string_types)
        self._property_access_key = value
class RevokeCredentialsResponse(Response):
    """
    Response of auth.revoke_credentials endpoint.

    :param revoked: Number of credentials revoked
    :type revoked: int
    """

    _service = "auth"
    _action = "revoke_credentials"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "revoked": {
                "description": "Number of credentials revoked",
                "enum": [0, 1],
                "type": ["integer", "null"],
            }
        },
        "type": "object",
    }

    def __init__(self, revoked=None, **kwargs):
        super(RevokeCredentialsResponse, self).__init__(**kwargs)
        self.revoked = revoked

    @schema_property("revoked")
    def revoked(self):
        return self._property_revoked

    @revoked.setter
    def revoked(self, value):
        if value is not None:
            # Coerce integral floats (e.g. 1.0) to int before validating.
            if isinstance(value, float) and value.is_integer():
                value = int(value)
            self.assert_isinstance(value, "revoked", six.integer_types)
        self._property_revoked = value
# Maps each request type to its corresponding response type; used by the
# session layer to deserialize replies from the auth service.
response_mapping = {
    LoginRequest: LoginResponse,
    CreateCredentialsRequest: CreateCredentialsResponse,
    GetCredentialsRequest: GetCredentialsResponse,
    RevokeCredentialsRequest: RevokeCredentialsResponse,
    EditUserRequest: EditUserResponse,
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,155 @@
"""
organization service
This service provides organization level operations
"""
import six
from clearml.backend_api.session import Request, Response, schema_property
class GetTagsRequest(Request):
    """
    Get all the user and system tags used for the company tasks and models.

    :param include_system: If set to 'true' then the list of the system tags is
        also returned. The default value is 'false'
    :type include_system: bool
    :param filter: Filter on entities to collect tags from
    :type filter: dict
    """

    _service = "organization"
    _action = "get_tags"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "filter": {
                "description": "Filter on entities to collect tags from",
                "properties": {
                    "system_tags": {
                        "description": "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded",
                        "items": {"type": "string"},
                        "type": "array",
                    },
                    "tags": {
                        "description": "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded",
                        "items": {"type": "string"},
                        "type": "array",
                    },
                },
                "type": ["object", "null"],
            },
            "include_system": {
                "default": False,
                "description": "If set to 'true' then the list of the system tags is also returned. The default value is 'false'",
                "type": ["boolean", "null"],
            },
        },
        "type": "object",
    }

    def __init__(self, include_system=False, filter=None, **kwargs):
        super(GetTagsRequest, self).__init__(**kwargs)
        self.include_system = include_system
        self.filter = filter

    @schema_property("include_system")
    def include_system(self):
        return self._property_include_system

    @include_system.setter
    def include_system(self, value):
        # None clears the field; otherwise only booleans are accepted.
        if value is not None:
            self.assert_isinstance(value, "include_system", (bool,))
        self._property_include_system = value

    @schema_property("filter")
    def filter(self):
        return self._property_filter

    @filter.setter
    def filter(self, value):
        # Free-form filter dict ("tags" / "system_tags" lists, see schema).
        # The name shadows the builtin but matches the public API field.
        if value is not None:
            self.assert_isinstance(value, "filter", (dict,))
        self._property_filter = value
class GetTagsResponse(Response):
    """
    Response of organization.get_tags endpoint.

    :param tags: The list of unique tag values
    :type tags: Sequence[str]
    :param system_tags: The list of unique system tag values. Returned only if
        'include_system' is set to 'true' in the request
    :type system_tags: Sequence[str]
    """

    _service = "organization"
    _action = "get_tags"
    _version = "2.13"
    _schema = {
        "definitions": {},
        "properties": {
            "system_tags": {
                "description": "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request",
                "items": {"type": "string"},
                "type": ["array", "null"],
            },
            "tags": {
                "description": "The list of unique tag values",
                "items": {"type": "string"},
                "type": ["array", "null"],
            },
        },
        "type": "object",
    }

    def __init__(self, tags=None, system_tags=None, **kwargs):
        super(GetTagsResponse, self).__init__(**kwargs)
        self.tags = tags
        self.system_tags = system_tags

    @schema_property("tags")
    def tags(self):
        return self._property_tags

    @tags.setter
    def tags(self, value):
        # Must be a sequence of strings (or None to clear).
        if value is not None:
            self.assert_isinstance(value, "tags", (list, tuple))
            self.assert_isinstance(value, "tags", six.string_types, is_array=True)
        self._property_tags = value

    @schema_property("system_tags")
    def system_tags(self):
        return self._property_system_tags

    @system_tags.setter
    def system_tags(self, value):
        # Must be a sequence of strings (or None to clear).
        if value is not None:
            self.assert_isinstance(value, "system_tags", (list, tuple))
            self.assert_isinstance(
                value, "system_tags", six.string_types, is_array=True
            )
        self._property_system_tags = value
# Maps each request type to its corresponding response type; used by the
# session layer to deserialize replies from the organization service.
response_mapping = {
    GetTagsRequest: GetTagsResponse,
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,7 @@ from pathlib2 import Path
from .base import IdObjectBase from .base import IdObjectBase
from .util import make_message from .util import make_message
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.services import models from ..backend_api.services import models, tasks
from ..storage import StorageManager from ..storage import StorageManager
from ..storage.helper import StorageHelper from ..storage.helper import StorageHelper
from ..utilities.async_manager import AsyncManagerMixin from ..utilities.async_manager import AsyncManagerMixin
@ -32,13 +32,14 @@ class _StorageUriMixin(object):
def create_dummy_model(upload_storage_uri=None, *args, **kwargs): def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
class DummyModel(models.Model, _StorageUriMixin): class DummyModel(models.Model, _StorageUriMixin):
def __init__(self, upload_storage_uri=None, *args, **kwargs): def __init__(self, upload_storage_uri=None, *_, **__):
super(DummyModel, self).__init__(*args, **kwargs) super(DummyModel, self).__init__(*_, **__)
self.upload_storage_uri = upload_storage_uri self.upload_storage_uri = upload_storage_uri
def update(self, **kwargs): def update(self, **a_kwargs):
for k, v in kwargs.items(): for k, v in a_kwargs.items():
setattr(self, k, v) setattr(self, k, v)
return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs) return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs)
@ -300,58 +301,24 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri return uri
def _complete_update_for_task(self, uri, task_id=None, name=None, comment=None, tags=None, override_model_id=None, def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
cb=None): if Session.check_min_api_version("2.13"):
if self._data: req = tasks.AddOrUpdateModelRequest(
name = name or self.data.name task=task_id, name=name, type=type_, model=model_id, iteration=iteration
comment = comment or self.data.comment )
tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags) elif type_ == "output":
uri = (uri or self.data.uri) if not override_model_id else None # backwards compatibility
req = models.UpdateForTaskRequest(task=task_id, override_model_id=model_id)
if tags: elif type_ == "input":
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags} # backwards compatibility, None
req = None
else: else:
extra = {} raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
res = self.send(
models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
override_model_id=override_model_id, **extra))
if self.id is None:
# update the model id. in case it was just created, this will trigger a reload of the model object
self.id = res.response.id if res else None
else:
self.reload()
try:
if cb:
cb(uri)
except Exception as ex:
self.log.warning('Failed calling callback on complete_update_for_task: %s' % str(ex))
pass
def update_for_task_and_upload( if req:
self, model_file, task_id, name=None, comment=None, tags=None, override_model_id=None, target_filename=None, self.send(req)
async_enable=False, cb=None, iteration=None):
""" Update the given model for a given task ID """
if async_enable:
callback = partial(
self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
override_model_id=override_model_id, cb=cb)
uri = self._upload_model(model_file, target_filename=target_filename,
async_enable=async_enable, cb=callback)
return uri
else:
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
if tags:
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
else:
extra = {}
_ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
override_model_id=override_model_id, iteration=iteration,
**extra))
return uri
def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None): self.reload()
self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
@property @property
def model_design(self): def model_design(self):
@ -480,6 +447,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:param name: Name for the new model :param name: Name for the new model
:param comment: Optional comment for the new model :param comment: Optional comment for the new model
:param child: Should the new model be a child of this model (default True) :param child: Should the new model be a child of this model (default True)
:param tags: Optional tags for the cloned model
:param task: Creating Task of the Model
:param ready: If True set the true flag for the newly created model
:return: The new model's ID :return: The new model's ID
""" """
data = self.data data = self.data

View File

@ -0,0 +1,117 @@
import re
import typing
from collections import UserDict, OrderedDict, UserList
from clearml.backend_api import Session
from clearml.backend_api.services import models
class ModelsList(UserList):
    """
    An ordered, list-like collection of models that also supports
    name-based lookup, i.e. a hybrid of a list and a mapping keyed
    by model name.
    """

    def __init__(self, models_dict):
        # type: (typing.OrderedDict["clearml.Model"]) -> None
        # Keep the name -> model mapping for string lookups; the backing
        # list (UserList.data) holds the models in insertion order.
        self._models = models_dict
        super(ModelsList, self).__init__(models_dict.values())

    def __getitem__(self, item):
        # String keys resolve through the name mapping; anything else
        # (integer index, slice) falls through to list semantics.
        if isinstance(item, str):
            return self._models[item]
        return super(ModelsList, self).__getitem__(item)

    def get(self, key, default=None):
        """Return the model for `key` (name or index), or `default` if absent.

        Fix: also catch IndexError so that an out-of-range integer index
        returns `default` (dict.get-like semantics) instead of raising.
        """
        try:
            return self[key]
        except (KeyError, IndexError):
            return default
class TaskModels(UserDict):
    """
    Dict-like accessor for a task's models with two keys, "input" and
    "output"; each value is a ModelsList mapping model names to models.
    """

    # Matches "Using model id: <id>" entries found in the task comment
    # (pre-2.13 servers record input models there).
    # Fix: pass re.IGNORECASE as a flag argument instead of an inline
    # "(?i)" group in mid-pattern, which raises re.error on Python >= 3.11
    # ("global flags not at the start of the expression").
    _input_models_re = re.compile(r"(Using model id: )(\w+)?", re.IGNORECASE)

    @property
    def input(self):
        # type: () -> ModelsList
        return self._input

    @property
    def output(self):
        # type: () -> ModelsList
        return self._output

    def __init__(self, task):
        # type: ("clearml.Task") -> None
        self._input = self._get_input_models(task)
        self._output = self._get_output_models(task)
        super(TaskModels, self).__init__({"input": self._input, "output": self._output})

    def _get_input_models(self, task):
        # type: ("clearml.Task") -> ModelsList
        """Collect the task's input models as a name -> Model ModelsList."""
        if Session.check_min_api_version("2.13"):
            parsed_ids = list(task.input_models_id.values())
        else:
            # since we'll fall back to the new task.models.input if no parsed IDs are found, only
            # extend this with the input model in case we're using 2.13 and have any parsed IDs or if we're using
            # a lower API version.
            parsed_ids = [i[-1] for i in self._input_models_re.findall(task.comment)]
            # get the last one on the Task
            parsed_ids.extend(list(task.input_models_id.values()))
        from clearml.model import Model

        def get_model(id_):
            # Resolve an ID into a Model, returning None if it is invalid.
            m = Model(model_id=id_)
            # noinspection PyBroadException
            try:
                # make sure the model is valid
                # noinspection PyProtectedMember
                m._get_model_data()
                return m
            except Exception:
                return None

        # remove duplicates and preserve order.
        # Fix: the mapping must be name -> model (the ModelsList contract,
        # as in the 2.13 branch below and in _get_output_models); the
        # original built it reversed (model object -> label string).
        input_models = OrderedDict(
            ("Input Model #{}".format(i), m)
            for i, m in enumerate(
                filter(None, map(get_model, OrderedDict.fromkeys(parsed_ids)))
            )
        )
        if not input_models and Session.check_min_api_version("2.13"):
            # Only new 2.13 task.models.input in case we have no parsed models
            input_models = OrderedDict(
                (x.name, get_model(x.model)) for x in task.data.models.input
            )
        return ModelsList(input_models)

    @staticmethod
    def _get_output_models(task):
        # type: ("clearml.Task") -> ModelsList
        """Collect the task's output models as a name -> Model ModelsList."""
        res = task.send(
            models.GetAllRequest(
                task=[task.id], order_by=["created"], only_fields=["id"]
            )
        )
        ids = [m.id for m in res.response.models or []] + list(
            task.output_models_id.values()
        )
        # remove duplicates and preserve order
        ids = list(OrderedDict.fromkeys(ids))
        # On 2.13+ servers the task carries the model names; otherwise
        # fall back to generated "Output Model #N" labels.
        id_to_name = (
            {x.model: x.name for x in task.data.models.output}
            if Session.check_min_api_version("2.13")
            else {}
        )

        def resolve_name(index, model_id):
            return id_to_name.get(model_id, "Output Model #{}".format(index))

        from clearml.model import Model

        output_models = OrderedDict(
            (resolve_name(i, m_id), Model(model_id=m_id)) for i, m_id in enumerate(ids)
        )
        return ModelsList(output_models)

View File

@ -24,7 +24,6 @@ except ImportError:
from collections import Iterable from collections import Iterable
import six import six
from collections import OrderedDict
from six.moves.urllib.parse import quote from six.moves.urllib.parse import quote
from ...utilities.locks import RLock as FileRLock from ...utilities.locks import RLock as FileRLock
@ -53,7 +52,7 @@ from ...storage.helper import StorageHelper, StorageError
from .access import AccessMixin from .access import AccessMixin
from .repo import ScriptInfo, pip_freeze from .repo import ScriptInfo, pip_freeze
from .hyperparams import HyperParams from .hyperparams import HyperParams
from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR, DOCKER_BASH_SETUP_ENV_VAR
from ...utilities.process.mp import SingletonLock from ...utilities.process.mp import SingletonLock
@ -150,8 +149,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
super(Task, self).__init__(id=task_id, session=session, log=log) super(Task, self).__init__(id=task_id, session=session, log=log)
self._project_name = None self._project_name = None
self._storage_uri = None self._storage_uri = None
self._input_model = None
self._output_model = None
self._metrics_manager = None self._metrics_manager = None
self.__reporter = None self.__reporter = None
self._curr_label_stats = {} self._curr_label_stats = {}
@ -334,8 +331,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._storage_uri = StorageHelper.conform_url(value) self._storage_uri = StorageHelper.conform_url(value)
self.data.output.destination = self._storage_uri self.data.output.destination = self._storage_uri
self._edit(output_dest=self._storage_uri or ('' if Session.check_min_api_version('2.3') else None)) self._edit(output_dest=self._storage_uri or ('' if Session.check_min_api_version('2.3') else None))
if self._storage_uri or self._output_model:
self.output_model.upload_storage_uri = self._storage_uri
@property @property
def storage_uri(self): def storage_uri(self):
@ -383,14 +378,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self.data.parent return self.data.parent
@property @property
def input_model_id(self): def input_models_id(self):
# type: () -> str # type: () -> Mapping[str, str]
return self.data.execution.model if not Session.check_min_api_version("2.13"):
model_id = self._get_task_property('execution.model', raise_on_error=False)
return {'Input Model': model_id} if model_id else {}
input_models = self._get_task_property('models.input', default=[]) or []
return {m.name: m.model for m in input_models}
@property @property
def output_model_id(self): def output_models_id(self):
# type: () -> str # type: () -> Mapping[str, str]
return self.data.output.model if not Session.check_min_api_version("2.13"):
model_id = self._get_task_property('output.model', raise_on_error=False)
return {'Output Model': model_id} if model_id else {}
output_models = self._get_task_property('models.output', default=[]) or []
return {m.name: m.model for m in output_models}
@property @property
def comment(self): def comment(self):
@ -425,34 +430,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Return the task's cached status (don't reload if we don't have to) """ """ Return the task's cached status (don't reload if we don't have to) """
return str(self.data.status) return str(self.data.status)
@property
def input_model(self):
# type: () -> Optional[Model]
""" A model manager used to handle the input model object """
model_id = self._get_task_property('execution.model', raise_on_error=False)
if not model_id:
return None
if self._input_model is None:
self._input_model = Model(
session=self.session,
model_id=model_id,
cache_dir=self.cache_dir,
log=self.log,
upload_storage_uri=None)
return self._input_model
@property
def output_model(self):
# type: () -> Optional[Model]
""" A model manager used to manage the output model object """
if self._output_model is None:
self._output_model = self._get_output_model(upload_required=True)
return self._output_model
def create_output_model(self):
# type: () -> Model
return self._get_output_model(upload_required=False, force=True)
def reload(self): def reload(self):
# type: () -> () # type: () -> ()
""" """
@ -461,12 +438,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" """
return super(Task, self).reload() return super(Task, self).reload()
def _get_output_model(self, upload_required=True, force=False, model_id=None): def _get_output_model(self, upload_required=True, model_id=None):
# type: (bool, bool, Optional[str]) -> Model # type: (bool, Optional[str]) -> Model
return Model( return Model(
session=self.session, session=self.session,
model_id=model_id or (None if force else self._get_task_property( model_id=model_id or None,
'output.model', raise_on_error=False, log_on_error=False)),
cache_dir=self.cache_dir, cache_dir=self.cache_dir,
upload_storage_uri=self.storage_uri or self.get_output_destination( upload_storage_uri=self.storage_uri or self.get_output_destination(
raise_on_error=upload_required, log_on_error=upload_required), raise_on_error=upload_required, log_on_error=upload_required),
@ -548,7 +524,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def reset(self, set_started_on_success=True): def reset(self, set_started_on_success=True):
# type: (bool) -> () # type: (bool) -> ()
""" Reset the task. Task will be reloaded following a successful reset. """ """
Reset the task. Task will be reloaded following a successful reset.
:param set_started_on_success: If True automatically set Task status to started after resetting it.
"""
self.send(tasks.ResetRequest(task=self.id)) self.send(tasks.ResetRequest(task=self.id))
if set_started_on_success: if set_started_on_success:
self.started() self.started()
@ -748,36 +728,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
res = self._edit(execution=execution) res = self._edit(execution=execution)
return res.response return res.response
def update_output_model(self, model_uri, name=None, comment=None, tags=None): def update_output_model(
# type: (str, Optional[str], Optional[str], Optional[Sequence[str]]) -> ()
"""
Update the Task's output model. Use this method to update the output model when you have a local model URI,
for example, storing the weights file locally, and specifying a ``file://path/to/file`` URI)
.. important::
This method only updates the model's metadata using the API. It does not upload any data.
:param model_uri: The URI of the updated model weights file.
:type model_uri: str
:param name: The updated model name. (Optional)
:type name: str
:param comment: The updated model description. (Optional)
:type comment: str
:param tags: The updated model tags. (Optional)
:type tags: [str]
"""
self._conditionally_start_task()
self._get_output_model(upload_required=False).update_for_task(
uri=model_uri, task_id=self.id, name=name, comment=comment, tags=tags)
def update_output_model_and_upload(
self, self,
model_file, # type: str model_path, # type: str
name=None, # type: Optional[str] name=None, # type: Optional[str]
comment=None, # type: Optional[str] comment=None, # type: Optional[str]
tags=None, # type: Optional[Sequence[str]] tags=None, # type: Optional[Sequence[str]]
async_enable=False, # type: bool model_name=None, # type: Optional[str]
cb=None, # type: Optional[Callable[[Optional[bool]], bool]]
iteration=None, # type: Optional[int] iteration=None, # type: Optional[int]
): ):
# type: (...) -> str # type: (...) -> str
@ -787,34 +744,29 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
then ClearML updates the model object associated with the Task an API call. The API call uses with the URI then ClearML updates the model object associated with the Task an API call. The API call uses with the URI
of the uploaded file, and other values provided by additional arguments. of the uploaded file, and other values provided by additional arguments.
:param str model_file: The path to the updated model weights file. :param model_path: A local weights file or folder to be uploaded.
:param str name: The updated model name. (Optional) If remote URI is provided (e.g. http:// or s3: // etc) then the URI is stored as is, without any upload
:param str comment: The updated model description. (Optional) :param name: The updated model name.
:param list tags: The updated model tags. (Optional) If not provided, the name is the model weights file filename without the extension.
:param bool async_enable: Request asynchronous upload :param comment: The updated model description. (Optional)
:param tags: The updated model tags. (Optional)
:param model_name: If provided the model name as it will appear in the model artifactory. (Optional)
Default: Task.name - name
:param iteration: iteration number for the current stored model (Optional)
- ``True`` - The API call returns immediately, while the upload and update are scheduled in another thread. :return: The URI of the uploaded weights file.
- ``False`` - The API call blocks until the upload completes, and the API call updating the model returns. Notice: upload is done is a background thread, while the function call returns immediately
(default)
:param callable cb: Asynchronous callback. A callback. If ``async_enable`` is set to ``True``,
this is a callback that is invoked once the asynchronous upload and update complete.
:param int iteration: iteration number for the current stored model (Optional)
:return: The URI of the uploaded weights file. If ``async_enable`` is set to ``True``,
this is the expected URI, as the upload is probably still in progress.
""" """
self._conditionally_start_task() from ...model import OutputModel
uri = self.output_model.update_for_task_and_upload( output_model = OutputModel(
model_file, self.id, name=name, comment=comment, tags=tags, async_enable=async_enable, cb=cb, task=self,
iteration=iteration name=model_name or ('{} - {}'.format(self.name, name) if name else self.name),
tags=tags,
comment=comment
) )
return uri output_model.connect(task=self, name=name)
url = output_model.update_weights(weights_filename=model_path, iteration=iteration)
def _conditionally_start_task(self): return url
# type: () -> ()
if str(self.status) == str(tasks.TaskStatusEnum.created):
self.started()
@property @property
def labels_stats(self): def labels_stats(self):
@ -832,16 +784,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
else: else:
self._curr_label_stats[label] = roi_stats[label] self._curr_label_stats[label] = roi_stats[label]
def set_input_model(self, model_id=None, model_name=None, update_task_design=True, update_task_labels=True): def set_input_model(
# type: (str, Optional[str], bool, bool) -> () self,
model_id=None,
model_name=None,
update_task_design=True,
update_task_labels=True,
name=None
):
# type: (str, Optional[str], bool, bool, Optional[str]) -> ()
""" """
Set a new input model for the Task. The model must be "ready" (status is ``Published``) to be used as the Set a new input model for the Task. The model must be "ready" (status is ``Published``) to be used as the
Task's input model. Task's input model.
:param model_id: The Id of the model on the **ClearML Server** (backend). If ``model_name`` is not specified, :param model_id: The Id of the model on the **ClearML Server** (backend). If ``model_name`` is not specified,
then ``model_id`` must be specified. then ``model_id`` must be specified.
:param model_name: The model name. The name is used to locate an existing model in the **ClearML Server** :param model_name: The model name in the artifactory. The model_name is used to locate an existing model
(backend). If ``model_id`` is not specified, then ``model_name`` must be specified. in the **ClearML Server** (backend). If ``model_id`` is not specified,
then ``model_name`` must be specified.
:param update_task_design: Update the Task's design :param update_task_design: Update the Task's design
- ``True`` - ClearML copies the Task's model design from the input model. - ``True`` - ClearML copies the Task's model design from the input model.
@ -851,11 +811,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
- ``True`` - ClearML copies the Task's label enumeration from the input model. - ``True`` - ClearML copies the Task's label enumeration from the input model.
- ``False`` - ClearML does not copy the Task's label enumeration from the input model. - ``False`` - ClearML does not copy the Task's label enumeration from the input model.
:param name: Model section name to be stored on the Task (unrelated to the model object name itself)
Default: the the model weight filename is used (excluding file extension)
""" """
if model_id is None and not model_name: if model_id is None and not model_name:
raise ValueError('Expected one of [model_id, model_name]') raise ValueError('Expected one of [model_id, model_name]')
if model_name: if model_name and not model_id:
# Try getting the model by name. Limit to 10 results. # Try getting the model by name. Limit to 10 results.
res = self.send( res = self.send(
models.GetAllRequest( models.GetAllRequest(
@ -864,7 +827,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
page=0, page=0,
page_size=10, page_size=10,
order_by=['-created'], order_by=['-created'],
only_fields=['id', 'created'] only_fields=['id', 'created', 'uri']
) )
) )
model = get_single_result(entity='model', query=model_name, results=res.response.models, log=self.log) model = get_single_result(entity='model', query=model_name, results=res.response.models, log=self.log)
@ -876,15 +839,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not model.ready: if not model.ready:
# raise ValueError('Model %s is not published (not ready)' % model_id) # raise ValueError('Model %s is not published (not ready)' % model_id)
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri)) self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
name = name or Path(model.uri).stem
else: else:
# clear the input model # clear the input model
model = None model = None
model_id = '' model_id = ''
name = name or 'Input Model'
with self._edit_lock: with self._edit_lock:
self.reload() self.reload()
# store model id # store model id
self.data.execution.model = model_id if Session.check_min_api_version("2.13"):
self.send(tasks.AddOrUpdateModelRequest(
task=self.id, name=name, model=model_id, type=tasks.ModelTypeEnum.input
))
else:
# backwards compatibility
self._set_task_property("execution.model", model_id, raise_on_error=False, log_on_error=False)
# Auto populate input field from model, if they are empty # Auto populate input field from model, if they are empty
if update_task_design and not self.data.execution.model_desc: if update_task_design and not self.data.execution.model_desc:
@ -1158,28 +1129,60 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _set_default_docker_image(self): def _set_default_docker_image(self):
# type: () -> () # type: () -> ()
if not DOCKER_IMAGE_ENV_VAR.exists(): if not DOCKER_IMAGE_ENV_VAR.exists() and not DOCKER_BASH_SETUP_ENV_VAR.exists():
return return
self.set_base_docker(DOCKER_IMAGE_ENV_VAR.get(default="")) self.set_base_docker(
docker_cmd=DOCKER_IMAGE_ENV_VAR.get(default=""),
docker_setup_bash_script=DOCKER_BASH_SETUP_ENV_VAR.get(default=""))
def set_base_docker(self, docker_cmd): def set_base_docker(self, docker_cmd, docker_arguments=None, docker_setup_bash_script=None):
# type: (str) -> () # type: (str, Optional[Union[str, Sequence[str]]], Optional[Union[str, Sequence[str]]]) -> ()
""" """
Set the base docker image for this experiment Set the base docker image for this experiment
If provided, this value will be used by clearml-agent to execute this experiment If provided, this value will be used by clearml-agent to execute this experiment
inside the provided docker image. inside the provided docker image.
When running remotely the call is ignored When running remotely the call is ignored
:param docker_cmd: docker container image (example: 'nvidia/cuda:11.1')
:param docker_arguments: docker execution parameters (example: '-e ENV=1')
:param docker_setup_bash_script: bash script to run at the
beginning of the docker before launching the Task itself. example: ['apt update', 'apt-get install -y gcc']
""" """
image = docker_cmd.split(' ')[0] if docker_cmd else ''
if not docker_arguments and docker_cmd:
arguments = docker_cmd.split(' ')[1:] if len(docker_cmd.split(' ')) > 1 else ''
else:
arguments = (docker_arguments if isinstance(docker_arguments, str) else ' '.join(docker_arguments)) \
if docker_arguments else ''
if docker_setup_bash_script:
setup_shell_script = docker_setup_bash_script \
if isinstance(docker_setup_bash_script, str) else '\n'.join(docker_setup_bash_script)
else:
setup_shell_script = None
with self._edit_lock: with self._edit_lock:
self.reload() self.reload()
execution = self.data.execution if Session.check_min_api_version("2.13"):
execution.docker_cmd = docker_cmd self.data.container = dict(image=image, arguments=arguments, setup_shell_script=setup_shell_script)
self._edit(execution=execution) else:
if setup_shell_script:
raise ValueError(
"Your ClearML-server does not support docker bash script feature, please upgrade.")
execution = self.data.execution
execution.docker_cmd = docker_cmd + (' {}'.format(arguments) if arguments else '')
self._edit(execution=execution)
def get_base_docker(self): def get_base_docker(self):
# type: () -> str # type: () -> str
"""Get the base Docker command (image) that is set for this experiment.""" """Get the base Docker command (image) that is set for this experiment."""
return self._get_task_property('execution.docker_cmd', raise_on_error=False, log_on_error=False) if Session.check_min_api_version("2.13"):
# backwards compatibility
container = self._get_task_property(
"container", raise_on_error=False, log_on_error=False, default={})
return (container.get('image', '') +
(' {}'.format(container['arguments']) if container.get('arguments', '') else '')) or None
else:
return self._get_task_property("execution.docker_cmd", raise_on_error=False, log_on_error=False)
def set_artifacts(self, artifacts_list=None): def set_artifacts(self, artifacts_list=None):
# type: (Sequence[tasks.Artifact]) -> () # type: (Sequence[tasks.Artifact]) -> ()
@ -1248,11 +1251,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# noinspection PyProtectedMember # noinspection PyProtectedMember
return Model._unwrap_design(design) return Model._unwrap_design(design)
def set_output_model_id(self, model_id):
# type: (str) -> ()
self.data.output.model = str(model_id)
self._edit(output=self.data.output)
def get_random_seed(self): def get_random_seed(self):
# type: () -> int # type: () -> int
# fixed seed for the time being # fixed seed for the time being
@ -1639,48 +1637,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" """
cls._force_use_pip_freeze = bool(force) cls._force_use_pip_freeze = bool(force)
def _get_models(self, model_type='output'):
# type: (str) -> Sequence[Model]
# model_type is either 'output' or 'input'
model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input'
if model_type == 'input':
regex = r'((?i)(Using model id: )(\w+)?)'
compiled = re.compile(regex)
ids = [i[-1] for i in re.findall(compiled, self.comment)] + (
[self.input_model_id] if self.input_model_id else [])
# remove duplicates and preserve order
ids = list(OrderedDict.fromkeys(ids))
from ...model import Model as TrainsModel
in_model = []
for i in ids:
m = TrainsModel(model_id=i)
# noinspection PyBroadException
try:
# make sure the model is is valid
# noinspection PyProtectedMember
m._get_model_data()
in_model.append(m)
except Exception:
pass
return in_model
else:
res = self.send(
models.GetAllRequest(
task=[self.id],
order_by=['created'],
only_fields=['id']
)
)
if not res.response.models:
return []
ids = [m.id for m in res.response.models] + ([self.output_model_id] if self.output_model_id else [])
# remove duplicates and preserve order
ids = list(OrderedDict.fromkeys(ids))
from ...model import Model as TrainsModel
return [TrainsModel(model_id=i) for i in ids]
def _get_default_report_storage_uri(self): def _get_default_report_storage_uri(self):
# type: () -> str # type: () -> str
if self._offline_mode: if self._offline_mode:
@ -1735,8 +1691,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
binary='', repository='', tag='', branch='', version_num='', entry_point='', binary='', repository='', tag='', branch='', version_num='', entry_point='',
working_dir='', requirements={}, diff='', working_dir='', requirements={}, diff='',
) )
if Session.check_min_api_version("2.13"):
self._data.models = tasks.TaskModels(input=[], output=[])
self._data.container = dict()
self._data.execution = tasks.Execution( self._data.execution = tasks.Execution(
artifacts=[], dataviews=[], model='', model_desc={}, model_labels={}, parameters={}, docker_cmd='') artifacts=[], dataviews=[], model='', model_desc={}, model_labels={}, parameters={}, docker_cmd='')
self._data.comment = str(comment) self._data.comment = str(comment)
self._storage_uri = None self._storage_uri = None
@ -1744,7 +1705,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._update_requirements('') self._update_requirements('')
if Session.check_min_api_version('2.9'): if Session.check_min_api_version('2.13'):
self._set_task_property("system_tags", system_tags)
self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
script=self._data.script, execution=self._data.execution, output_dest='',
hyperparams=dict(), configuration=dict(),
container=self._data.container, models=self._data.models)
elif Session.check_min_api_version('2.9'):
self._set_task_property("system_tags", system_tags) self._set_task_property("system_tags", system_tags)
self._edit(system_tags=self._data.system_tags, comment=self._data.comment, self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
script=self._data.script, execution=self._data.execution, output_dest='', script=self._data.script, execution=self._data.execution, output_dest='',

View File

@ -263,7 +263,6 @@ class WeightsFileHandler(object):
# ref_model = None # ref_model = None
# WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) # WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model) task.connect(trains_in_model)
# if we are running remotely we should deserialize the object # if we are running remotely we should deserialize the object
# because someone might have changed the config_dict # because someone might have changed the config_dict
@ -332,7 +331,7 @@ class WeightsFileHandler(object):
# check if we have output storage, and generate list of files to upload # check if we have output storage, and generate list of files to upload
if Path(model_info.local_model_path).is_dir(): if Path(model_info.local_model_path).is_dir():
files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()] files = [str(f) for f in Path(model_info.local_model_path).rglob('*')]
elif singlefile: elif singlefile:
files = [str(Path(model_info.local_model_path).absolute())] files = [str(Path(model_info.local_model_path).absolute())]
else: else:
@ -394,7 +393,8 @@ class WeightsFileHandler(object):
task=task, task=task,
config_dict=config_obj if isinstance(config_obj, dict) else None, config_dict=config_obj if isinstance(config_obj, dict) else None,
config_text=config_obj if isinstance(config_obj, str) else None, config_text=config_obj if isinstance(config_obj, str) else None,
name=(task.name + ' - ' + model_name) if model_name else None, name=None if in_model_id else '{} - {}'.format(
task.name, model_name or Path(model_info.local_model_path).stem),
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, framework=framework,
base_model_id=in_model_id base_model_id=in_model_id

View File

@ -9,6 +9,7 @@ DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "clearml_cache")
TASK_ID_ENV_VAR = EnvEntry("CLEARML_TASK_ID", "TRAINS_TASK_ID") TASK_ID_ENV_VAR = EnvEntry("CLEARML_TASK_ID", "TRAINS_TASK_ID")
DOCKER_IMAGE_ENV_VAR = EnvEntry("CLEARML_DOCKER_IMAGE", "TRAINS_DOCKER_IMAGE") DOCKER_IMAGE_ENV_VAR = EnvEntry("CLEARML_DOCKER_IMAGE", "TRAINS_DOCKER_IMAGE")
DOCKER_BASH_SETUP_ENV_VAR = EnvEntry("CLEARML_DOCKER_BASH_SCRIPT")
LOG_TO_BACKEND_ENV_VAR = EnvEntry("CLEARML_LOG_TASK_TO_BACKEND", "TRAINS_LOG_TASK_TO_BACKEND", type=bool) LOG_TO_BACKEND_ENV_VAR = EnvEntry("CLEARML_LOG_TASK_TO_BACKEND", "TRAINS_LOG_TASK_TO_BACKEND", type=bool)
NODE_ID_ENV_VAR = EnvEntry("CLEARML_NODE_ID", "TRAINS_NODE_ID", type=int) NODE_ID_ENV_VAR = EnvEntry("CLEARML_NODE_ID", "TRAINS_NODE_ID", type=int)
PROC_MASTER_ID_ENV_VAR = EnvEntry("CLEARML_PROC_MASTER_ID", "TRAINS_PROC_MASTER_ID", type=str) PROC_MASTER_ID_ENV_VAR = EnvEntry("CLEARML_PROC_MASTER_ID", "TRAINS_PROC_MASTER_ID", type=str)

View File

@ -17,6 +17,7 @@ from .backend_interface.util import validate_dict, get_single_result, mutually_e
from .debugging.log import get_logger from .debugging.log import get_logger
from .storage.cache import CacheManager from .storage.cache import CacheManager
from .storage.helper import StorageHelper from .storage.helper import StorageHelper
from .storage.util import get_common_path
from .utilities.enum import Options from .utilities.enum import Options
from .backend_interface import Task as _Task from .backend_interface import Task as _Task
from .backend_interface.model import create_dummy_model, Model as _Model from .backend_interface.model import create_dummy_model, Model as _Model
@ -791,12 +792,12 @@ class InputModel(Model):
:param only_published: If True filter out non-published (draft) models :param only_published: If True filter out non-published (draft) models
""" """
if not model_id: if not model_id:
models = self.query_models( found_models = self.query_models(
project_name=project, model_name=name, tags=tags, only_published=only_published) project_name=project, model_name=name, tags=tags, only_published=only_published)
if not models: if not found_models:
raise ValueError("Could not locate model with project={} name={} tags={} published={}".format( raise ValueError("Could not locate model with project={} name={} tags={} published={}".format(
project, name, tags, only_published)) project, name, tags, only_published))
model_id = models[0].id model_id = found_models[0].id
super(InputModel, self).__init__(model_id) super(InputModel, self).__init__(model_id)
@property @property
@ -804,8 +805,8 @@ class InputModel(Model):
# type: () -> str # type: () -> str
return self._base_model_id return self._base_model_id
def connect(self, task): def connect(self, task, name=None):
# type: (Task) -> None # type: (Task, Optional[str]) -> None
""" """
Connect the current model to a Task object, if the model is preexisting. Preexisting models include: Connect the current model to a Task object, if the model is preexisting. Preexisting models include:
@ -823,32 +824,37 @@ class InputModel(Model):
to execute in a worker. to execute in a worker.
:param object task: A Task object. :param object task: A Task object.
:param str name: The model name to be stored on the Task
(default the filename, of the model weights, without the file extension)
""" """
self._set_task(task) self._set_task(task)
model_id = None
# noinspection PyProtectedMember # noinspection PyProtectedMember
if running_remotely() and task.input_model and (task.is_main_task() or task._is_remote_main_task()): if running_remotely() and (task.is_main_task() or task._is_remote_main_task()):
self._base_model = task.input_model input_models = task.input_models_id
self._base_model_id = task.input_model.id try:
else: model_id = next(m_id for m_name, m_id in input_models if m_name == (name or 'Input Model'))
self._base_model_id = model_id
self._base_model = InputModel(model_id=model_id)._get_base_model()
except StopIteration:
model_id = None
if not model_id:
# we should set the task input model to point to us # we should set the task input model to point to us
model = self._get_base_model() model = self._get_base_model()
# try to store the input model id, if it is not empty # try to store the input model id, if it is not empty
# (Empty Should not happen)
if model.id != self._EMPTY_MODEL_ID: if model.id != self._EMPTY_MODEL_ID:
task.set_input_model(model_id=model.id) task.set_input_model(model_id=model.id, name=name)
# only copy the model design if the task has no design to begin with # only copy the model design if the task has no design to begin with
# noinspection PyProtectedMember # noinspection PyProtectedMember
if not self._task._get_model_config_text(): if not self._task._get_model_config_text() and model.model_design:
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._set_model_config(config_text=model.model_design) task._set_model_config(config_text=model.model_design)
if not self._task.get_labels_enumeration(): if not self._task.get_labels_enumeration() and model.data.labels:
task.set_model_label_enumeration(model.data.labels) task.set_model_label_enumeration(model.data.labels)
# If there was an output model connected, it may need to be updated by
# the newly connected input model
# noinspection PyProtectedMember
self.task._reconnect_output_model()
class OutputModel(BaseModel): class OutputModel(BaseModel):
""" """
@ -871,6 +877,8 @@ class OutputModel(BaseModel):
label enumeration using the **ClearML Web-App**. label enumeration using the **ClearML Web-App**.
""" """
_default_output_uri = None
@property @property
def published(self): def published(self):
# type: () -> bool # type: () -> bool
@ -1026,6 +1034,8 @@ class OutputModel(BaseModel):
self._model_local_filename = None self._model_local_filename = None
self._last_uploaded_url = None self._last_uploaded_url = None
self._base_model = None self._base_model = None
self._base_model_id = None
self._task_connect_name = None
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._floating_data = create_dummy_model( self._floating_data = create_dummy_model(
design=_Model._wrap_design(config_text), design=_Model._wrap_design(config_text),
@ -1037,34 +1047,37 @@ class OutputModel(BaseModel):
framework=framework, framework=framework,
upload_storage_uri=task.output_uri, upload_storage_uri=task.output_uri,
) )
if base_model_id: # If we have no real model ID, we are done
# noinspection PyBroadException if not base_model_id:
try: return
# noinspection PyProtectedMember
_base_model = self._task._get_output_model(model_id=base_model_id)
_base_model.update(
labels=self._floating_data.labels,
design=self._floating_data.design,
task_id=self._task.id,
project_id=self._task.project,
name=self._floating_data.name or self._task.name,
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
if (_base_model.comment and self._floating_data.comment and
self._floating_data.comment not in _base_model.comment)
else (_base_model.comment or self._floating_data.comment)),
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri
)
self._base_model = _base_model
self._floating_data = None
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
except Exception:
pass
self.connect(task)
def connect(self, task): # noinspection PyBroadException
# type: (Task) -> None try:
# noinspection PyProtectedMember
_base_model = self._task._get_output_model(model_id=base_model_id)
_base_model.update(
labels=self._floating_data.labels,
design=self._floating_data.design,
task_id=self._task.id,
project_id=self._task.project,
name=self._floating_data.name or self._task.name,
comment=('{}\n{}'.format(_base_model.comment, self._floating_data.comment)
if (_base_model.comment and self._floating_data.comment and
self._floating_data.comment not in _base_model.comment)
else (_base_model.comment or self._floating_data.comment)),
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri
)
self._base_model = _base_model
self._floating_data = None
name = self._task_connect_name or Path(_base_model.uri).stem
except Exception:
pass
self.connect(task, name=name)
def connect(self, task, name=None):
# type: (Task, Optional[str]) -> None
""" """
Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include: Connect the current model to a Task object, if the model is a preexisting model. Preexisting models include:
@ -1073,46 +1086,31 @@ class OutputModel(BaseModel):
- Models from another source, such as frameworks like TensorFlow. - Models from another source, such as frameworks like TensorFlow.
:param object task: A Task object. :param object task: A Task object.
:param str name: The model name as it would appear on the Task object.
The model object itself can have a different name,
this is designed to support multiple models used/created by a single Task.
Use examples would be GANs or model ensemble
""" """
if self._task != task: if self._task != task:
raise ValueError('Can only connect preexisting model to task, but this is a fresh model') raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
# noinspection PyProtectedMember if name:
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()): self._task_connect_name = name
if self._floating_data:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels
elif self._base_model:
# noinspection PyProtectedMember
self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None: # we should set the task input model to point to us
# we copy configuration / labels if they exist, obviously someone wants them as the output base model model = self._get_base_model()
# only copy the model design if the task has no design to begin with
# noinspection PyProtectedMember
if not self._task._get_model_config_text():
# noinspection PyProtectedMember # noinspection PyProtectedMember
design = _Model._unwrap_design(self._floating_data.design) task._set_model_config(config_text=model.model_design)
if design: if not self._task.get_labels_enumeration():
# noinspection PyProtectedMember task.set_model_label_enumeration(model.data.labels)
if not task._get_model_config_text():
if not Session.check_min_api_version('2.9'):
design = self._floating_data.design
# noinspection PyProtectedMember
task._set_model_config(config_text=design)
else:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text())
if self._floating_data.labels: if self._base_model:
task.set_model_label_enumeration(self._floating_data.labels) self._base_model.update_for_task(
else: task_id=self._task.id, model_id=self.id, type_="output", name=self._task_connect_name)
self._floating_data.labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
self.task._save_output_model(self)
def set_upload_destination(self, uri): def set_upload_destination(self, uri):
# type: (str) -> None # type: (str) -> None
@ -1235,7 +1233,9 @@ class OutputModel(BaseModel):
self._model_local_filename = weights_filename self._model_local_filename = weights_filename
# make sure the created model is updated: # make sure the created model is updated:
model = self._get_force_base_model() out_model_file_name = target_filename or weights_filename or register_uri
name = Path(out_model_file_name).stem if out_model_file_name else (self._task_connect_name or "Output Model")
model = self._get_force_base_model(task_model_entry=name)
if not model: if not model:
raise ValueError('Failed creating internal output model') raise ValueError('Failed creating internal output model')
@ -1299,10 +1299,6 @@ class OutputModel(BaseModel):
if is_package: if is_package:
self._set_package_tag() self._set_package_tag()
# make sure that if we are in dev move we report that we are training (not debugging)
# noinspection PyProtectedMember
self._task._output_model_updated()
return output_uri return output_uri
def update_weights_package( def update_weights_package(
@ -1347,13 +1343,17 @@ class OutputModel(BaseModel):
if not weights_filenames: if not weights_filenames:
weights_filenames = list(map(six.text_type, Path(weights_path).rglob('*'))) weights_filenames = list(map(six.text_type, Path(weights_path).rglob('*')))
elif weights_filenames and len(weights_filenames) > 1:
weights_path = get_common_path(weights_filenames)
# create packed model from all the files # create packed model from all the files
fd, zip_file = mkstemp(prefix='model_package.', suffix='.zip') fd, zip_file = mkstemp(prefix='model_package.', suffix='.zip')
try: try:
with zipfile.ZipFile(zip_file, 'w', allowZip64=True, compression=zipfile.ZIP_STORED) as zf: with zipfile.ZipFile(zip_file, 'w', allowZip64=True, compression=zipfile.ZIP_STORED) as zf:
for filename in weights_filenames: for filename in weights_filenames:
zf.write(filename, arcname=Path(filename).name) relative_file_name = Path(filename).name if not weights_path else \
Path(filename).absolute().relative_to(Path(weights_path).absolute()).as_posix()
zf.write(filename, arcname=relative_file_name)
finally: finally:
os.close(fd) os.close(fd)
@ -1471,24 +1471,38 @@ class OutputModel(BaseModel):
""" """
_Model.wait_for_results(timeout=timeout, max_num_uploads=max_num_uploads) _Model.wait_for_results(timeout=timeout, max_num_uploads=max_num_uploads)
def _get_force_base_model(self): @classmethod
def set_default_upload_uri(cls, output_uri):
    # type: (Optional[str]) -> None
    """
    Set the default upload destination URI shared by all OutputModel instances.

    :param output_uri: URL for uploading models, e.g.
        https://demofiles.demo.clear.ml, s3://bucket/, gs://bucket/, azure://bucket/, file:///mnt/shared/nfs
        A falsy value (None or empty string) clears the default.
    """
    if output_uri:
        cls._default_output_uri = str(output_uri)
    else:
        cls._default_output_uri = None
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model: if self._base_model:
return self._base_model return self._base_model
# create a new model from the task # create a new model from the task
self._base_model = self._task.create_output_model() # noinspection PyProtectedMember
self._base_model = self._task._get_output_model(model_id=None)
# update the model from the task inputs # update the model from the task inputs
labels = self._task.get_labels_enumeration() labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember # noinspection PyProtectedMember
config_text = self._task._get_model_config_text() config_text = self._task._get_model_config_text()
parent = self._task.output_model_id or self._task.input_model_id model_name = model_name or self._floating_data.name or self._task.name
task_model_entry = task_model_entry or self._task_connect_name or Path(self._get_model_data().uri).stem
parent = self._task.input_models_id.get(task_model_entry)
self._base_model.update( self._base_model.update(
labels=self._floating_data.labels or labels, labels=self._floating_data.labels or labels,
design=self._floating_data.design or config_text, design=self._floating_data.design or config_text,
task_id=self._task.id, task_id=self._task.id,
project_id=self._task.project, project_id=self._task.project,
parent_id=parent, parent_id=parent,
name=self._floating_data.name or self._task.name, name=model_name,
comment=self._floating_data.comment, comment=self._floating_data.comment,
tags=self._floating_data.tags, tags=self._floating_data.tags,
framework=self._floating_data.framework, framework=self._floating_data.framework,
@ -1499,11 +1513,13 @@ class OutputModel(BaseModel):
self._floating_data = None self._floating_data = None
# now we have to update the creator task so it points to us # now we have to update the creator task so it points to us
if self._task.status not in (self._task.TaskStatusEnum.created, self._task.TaskStatusEnum.in_progress): if str(self._task.status) not in (
str(self._task.TaskStatusEnum.created), str(self._task.TaskStatusEnum.in_progress)):
self._log.warning('Could not update last created model in Task {}, ' self._log.warning('Could not update last created model in Task {}, '
'Task status \'{}\' cannot be updated'.format(self._task.id, self._task.status)) 'Task status \'{}\' cannot be updated'.format(self._task.id, self._task.status))
else: else:
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id) self._base_model.update_for_task(
task_id=self._task.id, model_id=self.id, type_="output", name=task_model_entry)
return self._base_model return self._base_model

View File

@ -1,7 +1,8 @@
import hashlib import hashlib
import re import re
import sys import sys
from typing import Optional, Union from typing import Optional, Union, Sequence
from pathlib2 import Path
from six.moves.urllib.parse import quote, urlparse, urlunparse from six.moves.urllib.parse import quote, urlparse, urlunparse
import six import six
@ -214,3 +215,44 @@ def parse_size(size, binary=False):
return int(tokens[0] * k) return int(tokens[0] * k)
raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens)) raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens))
def get_common_path(list_of_files):
    # type: (Sequence[Union[str, Path]]) -> Optional[str]
    """
    Return the common path of a list of files

    :param list_of_files: list of files (str or Path objects)
    :return: Common path string (always absolute) or None if common path could not be found
    """
    if not list_of_files:
        return None
    # a single file has its parent as common path
    if len(list_of_files) == 1:
        return Path(list_of_files[0]).absolute().parent.as_posix()
    # find common path to support folder structure inside zip:
    # repeatedly shrink the candidate prefix to the longest run of leading
    # path components it shares with each file in the list
    common_parts = Path(list_of_files[0]).absolute().parts
    for a_file in list_of_files:
        file_parts = Path(a_file).absolute().parts
        shared = 0
        for ours, theirs in zip(common_parts, file_parts):
            if ours != theirs:
                break
            shared += 1
        common_parts = common_parts[:shared]
        if not common_parts:
            # no shared leading component at all (e.g. different drives) - no common path
            return None
    return Path(*common_parts).as_posix()

View File

@ -33,6 +33,7 @@ from .backend_interface.task import Task as _Task
from .backend_interface.task.log import TaskHandler from .backend_interface.task.log import TaskHandler
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo from .backend_interface.task.repo import ScriptInfo
from .backend_interface.task.models import TaskModels
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
from .binding.absl_bind import PatchAbsl from .binding.absl_bind import PatchAbsl
from .binding.artifacts import Artifacts, Artifact from .binding.artifacts import Artifacts, Artifact
@ -161,7 +162,6 @@ class Task(_Task):
super(Task, self).__init__(**kwargs) super(Task, self).__init__(**kwargs)
self._arguments = _Arguments(self) self._arguments = _Arguments(self)
self._logger = None self._logger = None
self._last_input_model_id = None
self._connected_output_model = None self._connected_output_model = None
self._dev_worker = None self._dev_worker = None
self._connected_parameter_type = None self._connected_parameter_type = None
@ -596,7 +596,7 @@ class Task(_Task):
logger.report_text('ClearML results page: {}'.format(task.get_output_log_web_page())) logger.report_text('ClearML results page: {}'.format(task.get_output_log_web_page()))
# Make sure we start the dev worker if required, otherwise it will only be started when we write # Make sure we start the dev worker if required, otherwise it will only be started when we write
# something to the log. # something to the log.
task._dev_mode_task_start() task._dev_mode_setup_worker()
if (not task._reporter or not task._reporter.is_alive()) and \ if (not task._reporter or not task._reporter.is_alive()) and \
is_sub_process_task_id and not cls._report_subprocess_enabled: is_sub_process_task_id and not cls._report_subprocess_enabled:
@ -809,11 +809,35 @@ class Task(_Task):
@property @property
def models(self): def models(self):
# type: () -> Dict[str, Sequence[Model]] # type: () -> Mapping[str, Sequence[Model]]
""" """
Read-only dictionary of the Task's loaded/stored models Read-only dictionary of the Task's loaded/stored models.
:return: A dictionary of models loaded/stored {'input': list(Model), 'output': list(Model)}. :return: A dictionary-like object with "input"/"output" keys and input/output properties, pointing to a
list-like object containing Model objects. Each list-like object also acts as a dictionary, mapping
model name to an appropriate model instance.
Get input/output models:
.. code-block:: py
task.models.input
task.models["input"]
task.models.output
task.models["output"]
Get the last output model:
.. code-block:: py
task.models.output[-1]
Get a model by name:
.. code-block:: py
task.models.output["model name"]
""" """
return self.get_models() return self.get_models()
@ -1051,7 +1075,7 @@ class Task(_Task):
) )
multi_config_support = Session.check_min_api_version('2.9') multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name: if multi_config_support and not name and not isinstance(mutable, (OutputModel, InputModel)):
name = self._default_configuration_section_name name = self._default_configuration_section_name
if not multi_config_support and name and name != self._default_configuration_section_name: if not multi_config_support and name and name != self._default_configuration_section_name:
@ -1498,14 +1522,16 @@ class Task(_Task):
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload) auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
def get_models(self): def get_models(self):
# type: () -> Dict[str, Sequence[Model]] # type: () -> Mapping[str, Sequence[Model]]
""" """
Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task
Input models are files loaded in the task, either manually or automatically logged Input models are files loaded in the task, either manually or automatically logged
Output models are files stored in the task, either manually or automatically logged Output models are files stored in the task, either manually or automatically logged
Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc. Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc.
:return: A dictionary with keys input/output, each is list of Model objects. :return: A dictionary-like object with "input"/"output" keys and input/output properties, pointing to a
list-like object containing Model objects. Each list-like object also acts as a dictionary, mapping
model name to an appropriate model instance.
Example: Example:
@ -1514,9 +1540,7 @@ class Task(_Task):
{'input': [clearml.Model()], 'output': [clearml.Model()]} {'input': [clearml.Model()], 'output': [clearml.Model()]}
""" """
task_models = {'input': self._get_models(model_type='input'), return TaskModels(self)
'output': self._get_models(model_type='output')}
return task_models
def is_current_task(self): def is_current_task(self):
# type: () -> bool # type: () -> bool
@ -1795,17 +1819,27 @@ class Task(_Task):
return self._hyper_params_manager.delete_hyper_params(*iterables) return self._hyper_params_manager.delete_hyper_params(*iterables)
def set_base_docker(self, docker_cmd): def set_base_docker(self, docker_cmd, docker_arguments=None, docker_setup_bash_script=None):
# type: (str) -> () # type: (str, Optional[Union[str, Sequence[str]]], Optional[Union[str, Sequence[str]]]) -> ()
""" """
Set the base docker image for this experiment Set the base docker image for this experiment
If provided, this value will be used by clearml-agent to execute this experiment If provided, this value will be used by clearml-agent to execute this experiment
inside the provided docker image. inside the provided docker image.
When running remotely the call is ignored
:param docker_cmd: docker container image (example: 'nvidia/cuda:11.1')
:param docker_arguments: docker execution parameters (example: '-e ENV=1')
:param docker_setup_bash_script: bash script to run at the
beginning of the docker before launching the Task itself. example: ['apt update', 'apt-get install -y gcc']
""" """
if not self.running_locally() and self.is_main_task(): if not self.running_locally() and self.is_main_task():
return return
super(Task, self).set_base_docker(docker_cmd) super(Task, self).set_base_docker(
docker_cmd=docker_cmd,
docker_arguments=docker_arguments,
docker_setup_bash_script=docker_setup_bash_script
)
def set_resource_monitor_iteration_timeout(self, seconds_from_start=1800): def set_resource_monitor_iteration_timeout(self, seconds_from_start=1800):
# type: (float) -> bool # type: (float) -> bool
@ -2418,7 +2452,7 @@ class Task(_Task):
if hasattr(task.data.execution, 'artifacts') else None if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in ( if ((str(task._status) in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (cls.archived_tag in task_tags) or task.output_models_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags) or (cls._development_tag not in task_tags)
or task_artifacts): or task_artifacts):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
@ -2558,25 +2592,27 @@ class Task(_Task):
def _connect_output_model(self, model, name=None): def _connect_output_model(self, model, name=None):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self) model.connect(self, name=name)
return model return model
def _save_output_model(self, model): def _save_output_model(self, model):
""" """
Save a reference to the connected output model. Deprecated: Save a reference to the connected output model.
:param model: The connected output model :param model: The connected output model
""" """
# deprecated
self._connected_output_model = model self._connected_output_model = model
def _reconnect_output_model(self): def _reconnect_output_model(self):
""" """
If there is a saved connected output model, connect it again. Deprecated: If there is a saved connected output model, connect it again.
This is needed if the input model is connected after the output model This is needed if the input model is connected after the output model
is connected, and then we will have to get the model design from the is connected, and then we will have to get the model design from the
input model by reconnecting. input model by reconnecting.
""" """
# Deprecated:
if self._connected_output_model: if self._connected_output_model:
self.connect(self._connected_output_model) self.connect(self._connected_output_model)
@ -2591,33 +2627,10 @@ class Task(_Task):
comment += '\n' comment += '\n'
comment += 'Using model id: {}'.format(model.id) comment += 'Using model id: {}'.format(model.id)
self.set_comment(comment) self.set_comment(comment)
if self._last_input_model_id and self._last_input_model_id != model.id:
self.log.info('Task connect, second input model is not supported, adding into comment section') model.connect(self, name)
return
self._last_input_model_id = model.id
model.connect(self)
return model return model
def _try_set_connected_parameter_type(self, option):
# """ Raise an error if current value is not None and not equal to the provided option value """
# value = self._connected_parameter_type
# if not value or value == option:
# self._connected_parameter_type = option
# return option
#
# def title(option):
# return " ".join(map(str.capitalize, option.split("_")))
#
# raise ValueError(
# "Task already connected to {}. "
# "Task can be connected to only one the following argument options: {}".format(
# title(value),
# ' / '.join(map(title, self._ConnectedParametersType._options())))
# )
# added support for multiple type connections through _Arguments
return option
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None): def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
# do not allow argparser to connect to jupyter notebook # do not allow argparser to connect to jupyter notebook
# noinspection PyBroadException # noinspection PyBroadException
@ -2631,8 +2644,6 @@ class Task(_Task):
except Exception: except Exception:
pass pass
self._try_set_connected_parameter_type(self._ConnectedParametersType.argparse)
if self.is_main_task(): if self.is_main_task():
argparser_update_currenttask(self) argparser_update_currenttask(self)
@ -2672,8 +2683,6 @@ class Task(_Task):
config_dict.clear() config_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict)) config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name) self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
@ -2686,8 +2695,6 @@ class Task(_Task):
return dictionary return dictionary
def _connect_task_parameters(self, attr_class, name=None): def _connect_task_parameters(self, attr_class, name=None):
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
parameters = self.get_parameters() parameters = self.get_parameters()
if not name: if not name:
@ -2726,18 +2733,6 @@ class Task(_Task):
if running_remotely(): if running_remotely():
super(Task, self)._validate(check_output_dest_credentials=False) super(Task, self)._validate(check_output_dest_credentials=False)
def _output_model_updated(self):
    """ Called when a connected output model is updated """
    # dev-mode bookkeeping: only the locally-running main task needs to react
    if not running_remotely() and self.is_main_task():
        # make sure the dev worker knows the task has started, in case it missed it so far
        self._dev_mode_task_start(model_updated=True)
def _dev_mode_task_start(self, model_updated=False):
    # type: (bool) -> None
    """ Called when we suspect the task has started running """
    # delegate to the dev worker setup; model_updated marks that a model update was the trigger
    self._dev_mode_setup_worker(model_updated=model_updated)
def _dev_mode_stop_task(self, stop_reason, pid=None): def _dev_mode_stop_task(self, stop_reason, pid=None):
# make sure we do not get called (by a daemon thread) after at_exit # make sure we do not get called (by a daemon thread) after at_exit
if self._at_exit_called: if self._at_exit_called:
@ -2797,7 +2792,7 @@ class Task(_Task):
else: else:
kill_ourselves.terminate() kill_ourselves.terminate()
def _dev_mode_setup_worker(self, model_updated=False): def _dev_mode_setup_worker(self):
if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode: if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode:
return return
@ -2886,7 +2881,7 @@ class Task(_Task):
is_sub_process = self.__is_subprocess() is_sub_process = self.__is_subprocess()
if True:##not is_sub_process: if True: # not is_sub_process: # todo: remove IF
# noinspection PyBroadException # noinspection PyBroadException
try: try:
wait_for_uploads = True wait_for_uploads = True