From b865fc00720911d3cda383bbe0ee58a936e07c14 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sun, 31 May 2020 12:05:09 +0300
Subject: [PATCH] If sys.argv doesn't point into a git repo, take file calling
 Task.init(). Support running code from a module (i.e. -m module)

---
 .../backend_interface/task/repo/scriptinfo.py | 65 +++++++++++++++----
 trains/backend_interface/task/task.py         | 13 +++-
 trains/binding/artifacts.py                   |  7 +-
 trains/task.py                                | 13 ++++
 trains/utilities/pigar/__main__.py            | 12 ++--
 trains/utilities/pigar/reqs.py                | 17 ++++-
 6 files changed, 106 insertions(+), 21 deletions(-)

diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py
index 0f3d8885..4b6d3a2b 100644
--- a/trains/backend_interface/task/repo/scriptinfo.py
+++ b/trains/backend_interface/task/repo/scriptinfo.py
@@ -83,6 +83,10 @@ class ScriptRequirements(object):
         except Exception:
             pass
 
+        # remove setuptools, we should not specify this module version. It is installed by default
+        if 'setuptools' in modules:
+            modules.pop('setuptools', {})
+
         # add forced requirements:
         # noinspection PyBroadException
         try:
@@ -494,18 +498,18 @@ class ScriptInfo(object):
         return ''
 
     @classmethod
-    def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
+    def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
         jupyter_filepath = cls._get_jupyter_notebook_filename()
         if jupyter_filepath:
-            script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
+            scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
         else:
-            script_path = Path(os.path.normpath(filepath)).absolute()
-            if not script_path.is_file():
+            scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
+            if all(not f.is_file() for f in scripts_path):
                 raise ScriptInfoError(
-                    "Script file [{}] could not be found".format(filepath)
+                    "Script file {} could not be found".format(scripts_path)
                 )
 
-        script_dir = script_path.parent
+        scripts_dir = [f.parent for f in scripts_path]
 
         def _log(msg, *args, **kwargs):
             if not log:
@@ -516,18 +520,25 @@
                 )
             )
 
-        plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
+        plugin = next((p for p in cls.plugins if any(p.exists(d) for d in scripts_dir)), None)
         repo_info = DetectionResult()
+        script_dir = scripts_dir[0]
+        script_path = scripts_path[0]
         if not plugin:
             log.info("No repository found, storing script code instead")
         else:
             try:
-                repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
+                for i, d in enumerate(scripts_dir):
+                    repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
+                    if not repo_info.is_empty():
+                        script_dir = d
+                        script_path = scripts_path[i]
+                        break
             except Exception as ex:
-                _log("no info for {} ({})", script_dir, ex)
+                _log("no info for {} ({})", scripts_dir, ex)
             else:
                 if repo_info.is_empty():
-                    _log("no info for {}", script_dir)
+                    _log("no info for {}", scripts_dir)
 
         repo_root = repo_info.root or script_dir
         if not plugin:
@@ -564,6 +575,8 @@
             diff=diff,
             requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
             binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
+            repo_root=repo_root,
+            jupyter_filepath=jupyter_filepath,
         )
 
         messages = []
@@ -581,16 +594,44 @@
             script_requirements)
 
     @classmethod
-    def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
+    def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
         try:
+            if not filepaths:
+                filepaths = [sys.argv[0], ]
             return cls._get_script_info(
-                filepath=filepath, check_uncommitted=check_uncommitted,
+                filepaths=filepaths, check_uncommitted=check_uncommitted,
                 create_requirements=create_requirements, log=log)
         except Exception as ex:
             if log:
                 log.warning("Failed auto-detecting task repository: {}".format(ex))
         return ScriptInfoResult(), None
 
+    @classmethod
+    def detect_running_module(cls, script_dict):
+        # noinspection PyBroadException
+        try:
+            # If this is jupyter, do not try to detect the running module, we know what we have.
+            if script_dict.get('jupyter_filepath'):
+                return script_dict
+
+            if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']:
+                argvs = ''
+                git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
+                for a in sys.argv[1:]:
+                    if git_root and os.path.exists(a):
+                        # check if common to project:
+                        a_abs = os.path.abspath(a)
+                        if os.path.commonpath([a_abs, git_root]) == git_root:
+                            # adjust path relative to working dir inside git repo
+                            a = ' ' + os.path.relpath(a_abs, os.path.join(git_root, script_dict['working_dir']))
+                    argvs += ' {}'.format(a)
+                # update the script entry point to match the real argv and module call
+                script_dict['entry_point'] = '-m {}{}'.format(
+                    vars(sys.modules['__main__'])['__package__'], (' ' + argvs) if argvs else '')
+        except Exception:
+            pass
+        return script_dict
+
     @classmethod
     def close(cls):
         _JupyterObserver.close()
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index 6c6ef490..66da7fa5 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import os
+import sys
 import re
 from enum import Enum
 from tempfile import gettempdir
@@ -247,20 +248,28 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             check_package_update_thread.start()
         # do not request requirements, because it might be a long process, and we first want to update the git repo
         result, script_requirements = ScriptInfo.get(
+            filepaths=[self._calling_filename, sys.argv[0], ],
             log=self.log, create_requirements=False, check_uncommitted=self._store_diff
         )
         for msg in result.warning_messages:
             self.get_logger().report_text(msg)
 
+        # store original entry point
+        entry_point = result.script.get('entry_point') if result.script else None
+
+        # check if we are running inside a module, then we should set our entrypoint
+        # to the module call including all argv's
+        result.script = ScriptInfo.detect_running_module(result.script)
+
         self.data.script = result.script
-        # Since we might run asynchronously, don't use self.data (lest someone else
+        # Since we might run asynchronously, don't use self.data (let someone else
         # overwrite it before we have a chance to call edit)
         self._edit(script=result.script)
         self.reload()
         # if jupyter is present, requirements will be created in the background, when saving a snapshot
         if result.script and script_requirements:
             entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
-                os.path.join(result.script['working_dir'], result.script['entry_point'])
+                os.path.join(result.script['working_dir'], entry_point)
             requirements, conda_requirements = script_requirements.get_requirements(
                 entry_point_filename=entry_point_filename)
 
diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py
index 8575bb32..42e96ddb 100644
--- a/trains/binding/artifacts.py
+++ b/trains/binding/artifacts.py
@@ -35,6 +35,10 @@
 try:
     import numpy as np
 except ImportError:
     np = None
+try:
+    from pathlib import Path as pathlib_Path
+except ImportError:
+    pathlib_Path = None
 
 class Artifact(object):
@@ -353,7 +357,8 @@
             uri = artifact_object
             artifact_type = 'custom'
             artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
-        elif isinstance(artifact_object, six.string_types + (Path,)):
+        elif isinstance(
+                artifact_object, six.string_types + (Path, pathlib_Path,) if pathlib_Path is not None else (Path,)):
             # check if single file
             artifact_object = Path(artifact_object)
 
diff --git a/trains/task.py b/trains/task.py
index 63771303..f23aaab8 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -147,6 +147,7 @@ class Task(_Task):
         self._detect_repo_async_thread = None
         self._resource_monitor = None
         self._artifacts_manager = Artifacts(self)
+        self._calling_filename = None
         # register atexit, so that we mark the task as stopped
         self._at_exit_called = False
 
@@ -1542,6 +1543,18 @@ class Task(_Task):
 
         # update current repository and put warning into logs
         if detect_repo:
+            # noinspection PyBroadException
+            try:
+                import traceback
+                stack = traceback.extract_stack(limit=10)
+                # NOTICE WE ARE ALWAYS 3 down from caller in stack!
+                for i in range(len(stack)-1, 0, -1):
+                    # look for the Task.init call, then the one above it is the callee module
+                    if stack[i].name == 'init':
+                        task._calling_filename = os.path.abspath(stack[i-1].filename)
+                        break
+            except Exception:
+                pass
             if in_dev_mode and cls.__detect_repo_async:
                 task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
                 task._detect_repo_async_thread.daemon = True
diff --git a/trains/utilities/pigar/__main__.py b/trains/utilities/pigar/__main__.py
index a63c7cec..02e4c8a1 100644
--- a/trains/utilities/pigar/__main__.py
+++ b/trains/utilities/pigar/__main__.py
@@ -5,7 +5,7 @@
 from __future__ import print_function, division, absolute_import
 import os
 import codecs
-from .reqs import project_import_modules, is_std_or_local_lib
+from .reqs import project_import_modules, is_std_or_local_lib, is_base_module
 from .utils import lines_diff
 from .log import logger
 from .modules import ReqsModules
@@ -82,16 +82,20 @@ class GenerateReqs(object):
                 guess.add(name, 0, modules[name])
 
         # add local modules, so we know what is used but not installed.
+        project_path = os.path.realpath(self._project_path)
         for name in self._local_mods:
             if name in modules:
                 if name in self._force_modules_reqs:
                     reqs.add(name, self._force_modules_reqs[name], modules[name])
                     continue
 
+                # if this is a base module, we have it in installed modules but package name is None
+                mod_path = os.path.realpath(self._local_mods[name])
+                if is_base_module(mod_path):
+                    continue
+
                 # if this is a folder of our project, we can safely ignore it
-                if os.path.commonpath([os.path.realpath(self._project_path)]) == \
-                        os.path.commonpath([os.path.realpath(self._project_path),
-                                            os.path.realpath(self._local_mods[name])]):
+                if os.path.commonpath([project_path]) == os.path.commonpath([project_path, mod_path]):
                     continue
 
                 relpath = os.path.relpath(self._local_mods[name], self._project_path)
diff --git a/trains/utilities/pigar/reqs.py b/trains/utilities/pigar/reqs.py
index d23f0cf3..d0560db5 100644
--- a/trains/utilities/pigar/reqs.py
+++ b/trains/utilities/pigar/reqs.py
@@ -333,12 +333,24 @@ def get_installed_pkgs_detail():
     mapping = new_mapping
 
     # HACK: prefer tensorflow_gpu over tensorflow
-    if 'tensorflow_gpu' in new_mapping:
-        new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
+    if 'tensorflow_gpu' in mapping:
+        mapping['tensorflow'] = mapping['tensorflow_gpu']
 
     return mapping
 
 
+def is_base_module(module_path):
+    python_base = '{}python{}.{}'.format(os.sep, sys.version_info.major, sys.version_info.minor)
+    for path in sys.path:
+        if os.path.isdir(path) and path.rstrip('/').endswith(
+                (python_base, )):
+            if not path[-1] == os.sep:
+                path += os.sep
+            if module_path.startswith(path):
+                return True
+    return False
+
+
 def _search_path(path):
     mapping = dict()
 
@@ -424,4 +436,5 @@ def _search_path(path):
                 with open(top_level, 'r') as f:
                     for line in f:
                         mapping[line.strip()] = ('-e', git_url)
+
     return mapping
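
Reviewer note (not part of the patch): a minimal, standalone sketch of the
stack-inspection fallback added to trains/task.py, useful for checking the
behaviour locally. The init() function below is a hypothetical stand-in for
Task.init(); only the traceback walk mirrors what the patch does.

    import os
    import traceback


    def init():
        # Walk the most recent stack frames, find this init() call and return
        # the file of the frame just above it, i.e. the module that called
        # init(). This is the path the patch hands to ScriptInfo.get(filepaths=...)
        # as a repository-detection candidate alongside sys.argv[0].
        stack = traceback.extract_stack(limit=10)
        for i in range(len(stack) - 1, 0, -1):
            if stack[i].name == 'init':
                return os.path.abspath(stack[i - 1].filename)
        return None


    if __name__ == '__main__':
        print(init())  # prints this file's absolute path when run directly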
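
A second sketch, for the "-m module" half of the change: roughly how the new
ScriptInfo.detect_running_module() rewrites the stored entry point. The helper
below is simplified and parameterised for illustration (the package name and
arguments are made up); the patch itself reads __main__.__package__ and
sys.argv directly.

    import os

    def rewrite_entry_point(script_dict, package, argv):
        # When the script was started with "python -m", store a module-style
        # entry point instead of a file path, re-basing argv entries that are
        # existing files inside the git repository.
        if script_dict.get('jupyter_filepath') or not package:
            return script_dict
        git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
        argvs = ''
        for a in argv:
            if git_root and os.path.exists(a):
                a_abs = os.path.abspath(a)
                if os.path.commonpath([a_abs, git_root]) == git_root:
                    a = os.path.relpath(a_abs, os.path.join(git_root, script_dict['working_dir']))
            argvs += ' {}'.format(a)
        script_dict['entry_point'] = '-m {}{}'.format(package, argvs)
        return script_dict

    example = {'repo_root': os.getcwd(), 'working_dir': '.', 'jupyter_filepath': None}
    print(rewrite_entry_point(example, package='mypackage', argv=['--epochs', '10'])['entry_point'])
    # -> -m mypackage --epochs 10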