Add Dataset.get_logger()

Fix using Dataset on current_task set Task type
This commit is contained in:
allegroai 2021-02-28 19:49:38 +02:00
parent f91645fdaf
commit ecc539ffb6

View File

@ -12,7 +12,7 @@ from zipfile import ZipFile, ZIP_DEFLATED
from attr import attrs, attrib
from pathlib2 import Path
from .. import Task, StorageManager
from .. import Task, StorageManager, Logger
from ..backend_api.session.client import APIClient
from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.util import mutually_exclusive, exact_match_regex
@ -65,8 +65,14 @@ class Dataset(object):
self._dataset_file_entries = {} # type: Dict[str, FileEntry]
# this will create a graph of all the dependencies we have, each entry lists it's own direct parents
self._dependency_graph = {} # type: Dict[str, List[str]]
self._created_task = False
if not task:
if task:
self._task_pinger = None
self._created_task = False
# If we are reusing the main current Task, make sure we set its type to data_processing
if str(task.task_type) != str(Task.TaskTypes.data_processing) and \
str(task.data.status) in ('created', 'in_progress'):
task.set_task_type(task_type=Task.TaskTypes.data_processing)
else:
self._created_task = True
task = Task.create(
project_name=dataset_project, task_name=dataset_name, task_type=Task.TaskTypes.data_processing)
@ -85,12 +91,9 @@ class Dataset(object):
# noinspection PyProtectedMember
task._edit(script=task.data.script)
# if the task is running make sure we ping to the server so it will not be aborted by a watchdog
if self._created_task and task.status in ('created', 'in_progress'):
# if the task is running make sure we ping to the server so it will not be aborted by a watchdog
self._task_pinger = DevWorker()
self._task_pinger.register(task, stop_signal_support=False)
else:
self._task_pinger = None
# store current dataset Task
self._task = task
@ -808,6 +811,15 @@ class Dataset(object):
return instance
def get_logger(self):
# type: () -> Logger
"""
Return a Logger object for the Dataset, allowing users to report statistics metrics
and debug samples on the Dataset itself
:return: Logger object
"""
return self._task.get_logger()
@classmethod
def squash(cls, dataset_name, dataset_ids=None, dataset_project_name_pairs=None, output_url=None):
# type: (str, Optional[Sequence[Union[str, Dataset]]],Optional[Sequence[(str, str)]], Optional[str]) -> Dataset