Add Dataset.get_logger()

Fix using Dataset on current_task set Task type
This commit is contained in:
allegroai 2021-02-28 19:49:38 +02:00
parent f91645fdaf
commit ecc539ffb6

View File

@ -12,7 +12,7 @@ from zipfile import ZipFile, ZIP_DEFLATED
from attr import attrs, attrib from attr import attrs, attrib
from pathlib2 import Path from pathlib2 import Path
from .. import Task, StorageManager from .. import Task, StorageManager, Logger
from ..backend_api.session.client import APIClient from ..backend_api.session.client import APIClient
from ..backend_interface.task.development.worker import DevWorker from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.util import mutually_exclusive, exact_match_regex from ..backend_interface.util import mutually_exclusive, exact_match_regex
@ -65,8 +65,14 @@ class Dataset(object):
self._dataset_file_entries = {} # type: Dict[str, FileEntry] self._dataset_file_entries = {} # type: Dict[str, FileEntry]
# this will create a graph of all the dependencies we have, each entry lists it's own direct parents # this will create a graph of all the dependencies we have, each entry lists it's own direct parents
self._dependency_graph = {} # type: Dict[str, List[str]] self._dependency_graph = {} # type: Dict[str, List[str]]
self._created_task = False if task:
if not task: self._task_pinger = None
self._created_task = False
# If we are reusing the main current Task, make sure we set its type to data_processing
if str(task.task_type) != str(Task.TaskTypes.data_processing) and \
str(task.data.status) in ('created', 'in_progress'):
task.set_task_type(task_type=Task.TaskTypes.data_processing)
else:
self._created_task = True self._created_task = True
task = Task.create( task = Task.create(
project_name=dataset_project, task_name=dataset_name, task_type=Task.TaskTypes.data_processing) project_name=dataset_project, task_name=dataset_name, task_type=Task.TaskTypes.data_processing)
@ -85,12 +91,9 @@ class Dataset(object):
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._edit(script=task.data.script) task._edit(script=task.data.script)
# if the task is running make sure we ping to the server so it will not be aborted by a watchdog # if the task is running make sure we ping to the server so it will not be aborted by a watchdog
if self._created_task and task.status in ('created', 'in_progress'):
self._task_pinger = DevWorker() self._task_pinger = DevWorker()
self._task_pinger.register(task, stop_signal_support=False) self._task_pinger.register(task, stop_signal_support=False)
else:
self._task_pinger = None
# store current dataset Task # store current dataset Task
self._task = task self._task = task
@ -808,6 +811,15 @@ class Dataset(object):
return instance return instance
def get_logger(self):
# type: () -> Logger
"""
Return a Logger object for the Dataset, allowing users to report statistics metrics
and debug samples on the Dataset itself
:return: Logger object
"""
return self._task.get_logger()
@classmethod @classmethod
def squash(cls, dataset_name, dataset_ids=None, dataset_project_name_pairs=None, output_url=None): def squash(cls, dataset_name, dataset_ids=None, dataset_project_name_pairs=None, output_url=None):
# type: (str, Optional[Sequence[Union[str, Dataset]]],Optional[Sequence[(str, str)]], Optional[str]) -> Dataset # type: (str, Optional[Sequence[Union[str, Dataset]]],Optional[Sequence[(str, str)]], Optional[str]) -> Dataset