If sys.argv[0] does not point into a git repository, fall back to the file that called Task.init(). Support running code as a module (e.g. python -m module)

This commit is contained in:
allegroai 2020-05-31 12:05:09 +03:00
parent 183ad248cf
commit b865fc0072
6 changed files with 106 additions and 21 deletions
trains
backend_interface/task
binding
task.py
utilities/pigar

View File

@ -83,6 +83,10 @@ class ScriptRequirements(object):
except Exception:
pass
# remove setuptools, we should not specify this module version. It is installed by default
if 'setuptools' in modules:
modules.pop('setuptools', {})
# add forced requirements:
# noinspection PyBroadException
try:
@ -494,18 +498,18 @@ class ScriptInfo(object):
return ''
@classmethod
def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath:
script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
else:
script_path = Path(os.path.normpath(filepath)).absolute()
if not script_path.is_file():
scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
if all(not f.is_file() for f in scripts_path):
raise ScriptInfoError(
"Script file [{}] could not be found".format(filepath)
"Script file {} could not be found".format(scripts_path)
)
script_dir = script_path.parent
scripts_dir = [f.parent for f in scripts_path]
def _log(msg, *args, **kwargs):
if not log:
@ -516,18 +520,25 @@ class ScriptInfo(object):
)
)
plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
plugin = next((p for p in cls.plugins if any(p.exists(d) for d in scripts_dir)), None)
repo_info = DetectionResult()
script_dir = scripts_dir[0]
script_path = scripts_path[0]
if not plugin:
log.info("No repository found, storing script code instead")
else:
try:
repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
for i, d in enumerate(scripts_dir):
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
if not repo_info.is_empty():
script_dir = d
script_path = scripts_path[i]
break
except Exception as ex:
_log("no info for {} ({})", script_dir, ex)
_log("no info for {} ({})", scripts_dir, ex)
else:
if repo_info.is_empty():
_log("no info for {}", script_dir)
_log("no info for {}", scripts_dir)
repo_root = repo_info.root or script_dir
if not plugin:
@ -564,6 +575,8 @@ class ScriptInfo(object):
diff=diff,
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
repo_root=repo_root,
jupyter_filepath=jupyter_filepath,
)
messages = []
@ -581,16 +594,44 @@ class ScriptInfo(object):
script_requirements)
@classmethod
def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
try:
if not filepaths:
filepaths = [sys.argv[0], ]
return cls._get_script_info(
filepath=filepath, check_uncommitted=check_uncommitted,
filepaths=filepaths, check_uncommitted=check_uncommitted,
create_requirements=create_requirements, log=log)
except Exception as ex:
if log:
log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None
@classmethod
def detect_running_module(cls, script_dict):
    """If the process was launched as a module (``python -m package ...``),
    rewrite ``entry_point`` in *script_dict* to the ``-m package`` form,
    appending the original command-line arguments.

    Arguments that are files inside the detected git repository are stored
    relative to the repo working dir, so the stored command stays portable.

    :param script_dict: script-info dict; reads ``jupyter_filepath``,
        ``repo_root`` and ``working_dir``, may rewrite ``entry_point``
    :return: the (possibly updated) script_dict
    """
    # noinspection PyBroadException
    try:
        # If this is jupyter, do not try to detect the running module,
        # we know what we have.
        if script_dict.get('jupyter_filepath'):
            return script_dict

        # a non-empty __package__ on __main__ means we were started with -m
        if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']:
            argvs = ''
            git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
            for a in sys.argv[1:]:
                if git_root and os.path.exists(a):
                    # check if the file argument is common to the project:
                    a_abs = os.path.abspath(a)
                    if os.path.commonpath([a_abs, git_root]) == git_root:
                        # adjust path relative to working dir inside git repo
                        a = os.path.relpath(
                            a_abs, os.path.join(git_root, script_dict['working_dir']))
                # argvs carries exactly one leading space per argument;
                # the original code prepended extra spaces here and again
                # below, yielding '-m pkg   arg' with duplicated separators
                argvs += ' {}'.format(a)

            # update the script entry point to match the real argv and module call
            script_dict['entry_point'] = '-m {}{}'.format(
                vars(sys.modules['__main__'])['__package__'], argvs)
    except Exception:
        # best effort only - keep the original script info on any failure
        pass
    return script_dict
@classmethod
def close(cls):
    # delegate shutdown to the module-level Jupyter observer
    # (defined elsewhere in this file)
    _JupyterObserver.close()

View File

@ -2,6 +2,7 @@
import itertools
import logging
import os
import sys
import re
from enum import Enum
from tempfile import gettempdir
@ -247,20 +248,28 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
check_package_update_thread.start()
# do not request requirements, because it might be a long process, and we first want to update the git repo
result, script_requirements = ScriptInfo.get(
filepaths=[self._calling_filename, sys.argv[0], ],
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
)
for msg in result.warning_messages:
self.get_logger().report_text(msg)
# store original entry point
entry_point = result.script.get('entry_point') if result.script else None
# check if we are running inside a module, then we should set our entrypoint
# to the module call including all argv's
result.script = ScriptInfo.detect_running_module(result.script)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else
# Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if result.script and script_requirements:
entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
os.path.join(result.script['working_dir'], result.script['entry_point'])
os.path.join(result.script['working_dir'], entry_point)
requirements, conda_requirements = script_requirements.get_requirements(
entry_point_filename=entry_point_filename)

View File

@ -35,6 +35,10 @@ try:
import numpy as np
except ImportError:
np = None
try:
from pathlib import Path as pathlib_Path
except ImportError:
pathlib_Path = None
class Artifact(object):
@ -353,7 +357,8 @@ class Artifacts(object):
uri = artifact_object
artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
elif isinstance(artifact_object, six.string_types + (Path,)):
elif isinstance(
artifact_object, six.string_types + (Path, pathlib_Path,) if pathlib_Path is not None else (Path,)):
# check if single file
artifact_object = Path(artifact_object)

View File

@ -147,6 +147,7 @@ class Task(_Task):
self._detect_repo_async_thread = None
self._resource_monitor = None
self._artifacts_manager = Artifacts(self)
self._calling_filename = None
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
@ -1542,6 +1543,18 @@ class Task(_Task):
# update current repository and put warning into logs
if detect_repo:
# noinspection PyBroadException
try:
import traceback
stack = traceback.extract_stack(limit=10)
# NOTICE WE ARE ALWAYS 3 down from caller in stack!
for i in range(len(stack)-1, 0, -1):
# look for the Task.init call, then the one above it is the callee module
if stack[i].name == 'init':
task._calling_filename = os.path.abspath(stack[i-1].filename)
break
except Exception:
pass
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True

View File

@ -5,7 +5,7 @@ from __future__ import print_function, division, absolute_import
import os
import codecs
from .reqs import project_import_modules, is_std_or_local_lib
from .reqs import project_import_modules, is_std_or_local_lib, is_base_module
from .utils import lines_diff
from .log import logger
from .modules import ReqsModules
@ -82,16 +82,20 @@ class GenerateReqs(object):
guess.add(name, 0, modules[name])
# add local modules, so we know what is used but not installed.
project_path = os.path.realpath(self._project_path)
for name in self._local_mods:
if name in modules:
if name in self._force_modules_reqs:
reqs.add(name, self._force_modules_reqs[name], modules[name])
continue
# if this is a base module, we have it in installed modules but package name is None
mod_path = os.path.realpath(self._local_mods[name])
if is_base_module(mod_path):
continue
# if this is a folder of our project, we can safely ignore it
if os.path.commonpath([os.path.realpath(self._project_path)]) == \
os.path.commonpath([os.path.realpath(self._project_path),
os.path.realpath(self._local_mods[name])]):
if os.path.commonpath([project_path]) == os.path.commonpath([project_path, mod_path]):
continue
relpath = os.path.relpath(self._local_mods[name], self._project_path)

View File

@ -333,12 +333,24 @@ def get_installed_pkgs_detail():
mapping = new_mapping
# HACK: prefer tensorflow_gpu over tensorflow
if 'tensorflow_gpu' in new_mapping:
new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
if 'tensorflow_gpu' in mapping:
mapping['tensorflow'] = mapping['tensorflow_gpu']
return mapping
def is_base_module(module_path):
    """Return True if *module_path* lives under the interpreter's
    standard-library directory, i.e. a ``sys.path`` entry whose name
    ends in ``pythonX.Y`` for the running interpreter version.

    :param module_path: absolute file-system path of a module file
    :return: True if the module belongs to the Python base installation
    """
    python_base = '{}python{}.{}'.format(
        os.sep, sys.version_info.major, sys.version_info.minor)
    for path in sys.path:
        # strip trailing separators before the suffix test; the original
        # rstrip('/') missed os.sep and so failed on Windows paths
        if os.path.isdir(path) and path.rstrip('/' + os.sep).endswith(python_base):
            if not path.endswith(os.sep):
                path += os.sep
            # prefix match: the module file is inside the stdlib folder
            if module_path.startswith(path):
                return True
    return False
def _search_path(path):
mapping = dict()
@ -424,4 +436,5 @@ def _search_path(path):
with open(top_level, 'r') as f:
for line in f:
mapping[line.strip()] = ('-e', git_url)
return mapping