From 30eaed79ea44bb64e9d5f8d35a70e27ae4845cae Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 6 Jan 2020 17:20:15 +0200 Subject: [PATCH] Add warning when automatic argument parser binding cannot be turned off --- trains/task.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/trains/task.py b/trains/task.py index 5716ce49..5f157afa 100644 --- a/trains/task.py +++ b/trains/task.py @@ -5,6 +5,7 @@ import sys import threading import time from argparse import ArgumentParser +from collections import Callable from tempfile import mkstemp try: @@ -16,8 +17,8 @@ from typing import Optional import psutil import six +from pathlib2 import Path -from .binding.joblib_bind import PatchedJoblib from .backend_api.services import tasks, projects, queues from .backend_api.session.session import Session from .backend_interface.model import Model as BackendModel @@ -26,6 +27,14 @@ from .backend_interface.task.args import _Arguments from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.repo import ScriptInfo from .backend_interface.util import get_single_result, exact_match_regex, make_message +from .binding.absl_bind import PatchAbsl +from .binding.artifacts import Artifacts, Artifact +from .binding.environ_bind import EnvironmentBind, PatchOsFork +from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO +from .binding.frameworks.tensorflow_bind import TensorflowBinding +from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO +from .binding.joblib_bind import PatchedJoblib +from .binding.matplotlib_bind import PatchedMatplotlib from .config import config, PROC_MASTER_ID_ENV_VAR, DEV_TASK_NO_REUSE from .config import running_remotely, get_remote_task_id from .config.cache import SessionCache @@ -34,20 +43,13 @@ from .errors import UsageError from .logger import Logger from .model import InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters -from .binding.artifacts import Artifacts, Artifact -from .binding.environ_bind import EnvironmentBind, PatchOsFork -from .binding.absl_bind import PatchAbsl from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask -from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO -from .binding.frameworks.tensorflow_bind import TensorflowBinding -from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO -from .binding.matplotlib_bind import PatchedMatplotlib -from .utilities.resource_monitor import ResourceMonitor -from .utilities.seed import make_deterministic from .utilities.dicts import ReadOnlyDict from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \ nested_from_flat_dictionary +from .utilities.resource_monitor import ResourceMonitor +from .utilities.seed import make_deterministic class Task(_Task): @@ -305,6 +307,12 @@ class Task(_Task): if argparser_parseargs_called(): parser, parsed_args = get_argparser_last_args() task._connect_argparse(parser=parser, parsed_args=parsed_args) + elif argparser_parseargs_called(): + # parse_args was automatically patched, but auto_connect_arg_parser is False... + raise UsageError("ArgumentParser.parse_args() was automatically connected to this task, " + "although auto_connect_arg_parser is turned off!" + "When turning off auto_connect_arg_parser, call Task.init() " + "before calling ArgumentParser.parse_args()") # Make sure we start the logger, it will patch the main logging object and pipe all output # if we are running locally and using development mode worker, we will pipe all stdout to logger. @@ -534,15 +542,15 @@ class Task(_Task): :raise: raise exception on unsupported objects """ - dispatch = OrderedDict(( + dispatch = ( (OutputModel, self._connect_output_model), (InputModel, self._connect_input_model), (ArgumentParser, self._connect_argparse), (dict, self._connect_dictionary), (TaskParameters, self._connect_task_parameters), - )) + ) - for mutable_type, method in dispatch.items(): + for mutable_type, method in dispatch: if isinstance(mutable, mutable_type): return method(mutable) @@ -999,6 +1007,8 @@ class Task(_Task): if in_dev_mode: # update this session, for later use cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id) + # set default docker image from env. + task._set_default_docker_image() # mark the task as started task.started()