If sys.argv doesn't point into a git repo, take file calling Task.init(). Support running code from a module (i.e. -m module)

This commit is contained in:
allegroai 2020-05-31 12:05:09 +03:00
parent 183ad248cf
commit b865fc0072
6 changed files with 106 additions and 21 deletions

View File

@ -83,6 +83,10 @@ class ScriptRequirements(object):
except Exception: except Exception:
pass pass
# remove setuptools, we should not specify this module version. It is installed by default
if 'setuptools' in modules:
modules.pop('setuptools', {})
# add forced requirements: # add forced requirements:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -494,18 +498,18 @@ class ScriptInfo(object):
return '' return ''
@classmethod @classmethod
def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None): def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
jupyter_filepath = cls._get_jupyter_notebook_filename() jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath: if jupyter_filepath:
script_path = Path(os.path.normpath(jupyter_filepath)).absolute() scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
else: else:
script_path = Path(os.path.normpath(filepath)).absolute() scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
if not script_path.is_file(): if all(not f.is_file() for f in scripts_path):
raise ScriptInfoError( raise ScriptInfoError(
"Script file [{}] could not be found".format(filepath) "Script file {} could not be found".format(scripts_path)
) )
script_dir = script_path.parent scripts_dir = [f.parent for f in scripts_path]
def _log(msg, *args, **kwargs): def _log(msg, *args, **kwargs):
if not log: if not log:
@ -516,18 +520,25 @@ class ScriptInfo(object):
) )
) )
plugin = next((p for p in cls.plugins if p.exists(script_dir)), None) plugin = next((p for p in cls.plugins if any(p.exists(d) for d in scripts_dir)), None)
repo_info = DetectionResult() repo_info = DetectionResult()
script_dir = scripts_dir[0]
script_path = scripts_path[0]
if not plugin: if not plugin:
log.info("No repository found, storing script code instead") log.info("No repository found, storing script code instead")
else: else:
try: try:
repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted) for i, d in enumerate(scripts_dir):
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
if not repo_info.is_empty():
script_dir = d
script_path = scripts_path[i]
break
except Exception as ex: except Exception as ex:
_log("no info for {} ({})", script_dir, ex) _log("no info for {} ({})", scripts_dir, ex)
else: else:
if repo_info.is_empty(): if repo_info.is_empty():
_log("no info for {}", script_dir) _log("no info for {}", scripts_dir)
repo_root = repo_info.root or script_dir repo_root = repo_info.root or script_dir
if not plugin: if not plugin:
@ -564,6 +575,8 @@ class ScriptInfo(object):
diff=diff, diff=diff,
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None, requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor), binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
repo_root=repo_root,
jupyter_filepath=jupyter_filepath,
) )
messages = [] messages = []
@ -581,16 +594,44 @@ class ScriptInfo(object):
script_requirements) script_requirements)
@classmethod @classmethod
def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None): def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
try: try:
if not filepaths:
filepaths = [sys.argv[0], ]
return cls._get_script_info( return cls._get_script_info(
filepath=filepath, check_uncommitted=check_uncommitted, filepaths=filepaths, check_uncommitted=check_uncommitted,
create_requirements=create_requirements, log=log) create_requirements=create_requirements, log=log)
except Exception as ex: except Exception as ex:
if log: if log:
log.warning("Failed auto-detecting task repository: {}".format(ex)) log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None return ScriptInfoResult(), None
@classmethod
def detect_running_module(cls, script_dict):
# noinspection PyBroadException
try:
# If this is jupyter, do not try to detect the running module, we know what we have.
if script_dict.get('jupyter_filepath'):
return script_dict
if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']:
argvs = ''
git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
for a in sys.argv[1:]:
if git_root and os.path.exists(a):
# check if common to project:
a_abs = os.path.abspath(a)
if os.path.commonpath([a_abs, git_root]) == git_root:
# adjust path relative to working dir inside git repo
a = ' ' + os.path.relpath(a_abs, os.path.join(git_root, script_dict['working_dir']))
argvs += ' {}'.format(a)
# update the script entry point to match the real argv and module call
script_dict['entry_point'] = '-m {}{}'.format(
vars(sys.modules['__main__'])['__package__'], (' ' + argvs) if argvs else '')
except Exception:
pass
return script_dict
@classmethod @classmethod
def close(cls): def close(cls):
_JupyterObserver.close() _JupyterObserver.close()

View File

@ -2,6 +2,7 @@
import itertools import itertools
import logging import logging
import os import os
import sys
import re import re
from enum import Enum from enum import Enum
from tempfile import gettempdir from tempfile import gettempdir
@ -247,20 +248,28 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
check_package_update_thread.start() check_package_update_thread.start()
# do not request requirements, because it might be a long process, and we first want to update the git repo # do not request requirements, because it might be a long process, and we first want to update the git repo
result, script_requirements = ScriptInfo.get( result, script_requirements = ScriptInfo.get(
filepaths=[self._calling_filename, sys.argv[0], ],
log=self.log, create_requirements=False, check_uncommitted=self._store_diff log=self.log, create_requirements=False, check_uncommitted=self._store_diff
) )
for msg in result.warning_messages: for msg in result.warning_messages:
self.get_logger().report_text(msg) self.get_logger().report_text(msg)
# store original entry point
entry_point = result.script.get('entry_point') if result.script else None
# check if we are running inside a module, then we should set our entrypoint
# to the module call including all argv's
result.script = ScriptInfo.detect_running_module(result.script)
self.data.script = result.script self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else # Since we might run asynchronously, don't use self.data (let someone else
# overwrite it before we have a chance to call edit) # overwrite it before we have a chance to call edit)
self._edit(script=result.script) self._edit(script=result.script)
self.reload() self.reload()
# if jupyter is present, requirements will be created in the background, when saving a snapshot # if jupyter is present, requirements will be created in the background, when saving a snapshot
if result.script and script_requirements: if result.script and script_requirements:
entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \ entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
os.path.join(result.script['working_dir'], result.script['entry_point']) os.path.join(result.script['working_dir'], entry_point)
requirements, conda_requirements = script_requirements.get_requirements( requirements, conda_requirements = script_requirements.get_requirements(
entry_point_filename=entry_point_filename) entry_point_filename=entry_point_filename)

View File

@ -35,6 +35,10 @@ try:
import numpy as np import numpy as np
except ImportError: except ImportError:
np = None np = None
try:
from pathlib import Path as pathlib_Path
except ImportError:
pathlib_Path = None
class Artifact(object): class Artifact(object):
@ -353,7 +357,8 @@ class Artifacts(object):
uri = artifact_object uri = artifact_object
artifact_type = 'custom' artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0] artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
elif isinstance(artifact_object, six.string_types + (Path,)): elif isinstance(
artifact_object, six.string_types + (Path, pathlib_Path,) if pathlib_Path is not None else (Path,)):
# check if single file # check if single file
artifact_object = Path(artifact_object) artifact_object = Path(artifact_object)

View File

@ -147,6 +147,7 @@ class Task(_Task):
self._detect_repo_async_thread = None self._detect_repo_async_thread = None
self._resource_monitor = None self._resource_monitor = None
self._artifacts_manager = Artifacts(self) self._artifacts_manager = Artifacts(self)
self._calling_filename = None
# register atexit, so that we mark the task as stopped # register atexit, so that we mark the task as stopped
self._at_exit_called = False self._at_exit_called = False
@ -1542,6 +1543,18 @@ class Task(_Task):
# update current repository and put warning into logs # update current repository and put warning into logs
if detect_repo: if detect_repo:
# noinspection PyBroadException
try:
import traceback
stack = traceback.extract_stack(limit=10)
# NOTICE WE ARE ALWAYS 3 down from caller in stack!
for i in range(len(stack)-1, 0, -1):
# look for the Task.init call, then the one above it is the callee module
if stack[i].name == 'init':
task._calling_filename = os.path.abspath(stack[i-1].filename)
break
except Exception:
pass
if in_dev_mode and cls.__detect_repo_async: if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository) task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True task._detect_repo_async_thread.daemon = True

View File

@ -5,7 +5,7 @@ from __future__ import print_function, division, absolute_import
import os import os
import codecs import codecs
from .reqs import project_import_modules, is_std_or_local_lib from .reqs import project_import_modules, is_std_or_local_lib, is_base_module
from .utils import lines_diff from .utils import lines_diff
from .log import logger from .log import logger
from .modules import ReqsModules from .modules import ReqsModules
@ -82,16 +82,20 @@ class GenerateReqs(object):
guess.add(name, 0, modules[name]) guess.add(name, 0, modules[name])
# add local modules, so we know what is used but not installed. # add local modules, so we know what is used but not installed.
project_path = os.path.realpath(self._project_path)
for name in self._local_mods: for name in self._local_mods:
if name in modules: if name in modules:
if name in self._force_modules_reqs: if name in self._force_modules_reqs:
reqs.add(name, self._force_modules_reqs[name], modules[name]) reqs.add(name, self._force_modules_reqs[name], modules[name])
continue continue
# if this is a base module, we have it in installed modules but package name is None
mod_path = os.path.realpath(self._local_mods[name])
if is_base_module(mod_path):
continue
# if this is a folder of our project, we can safely ignore it # if this is a folder of our project, we can safely ignore it
if os.path.commonpath([os.path.realpath(self._project_path)]) == \ if os.path.commonpath([project_path]) == os.path.commonpath([project_path, mod_path]):
os.path.commonpath([os.path.realpath(self._project_path),
os.path.realpath(self._local_mods[name])]):
continue continue
relpath = os.path.relpath(self._local_mods[name], self._project_path) relpath = os.path.relpath(self._local_mods[name], self._project_path)

View File

@ -333,12 +333,24 @@ def get_installed_pkgs_detail():
mapping = new_mapping mapping = new_mapping
# HACK: prefer tensorflow_gpu over tensorflow # HACK: prefer tensorflow_gpu over tensorflow
if 'tensorflow_gpu' in new_mapping: if 'tensorflow_gpu' in mapping:
new_mapping['tensorflow'] = new_mapping['tensorflow_gpu'] mapping['tensorflow'] = mapping['tensorflow_gpu']
return mapping return mapping
def is_base_module(module_path):
python_base = '{}python{}.{}'.format(os.sep, sys.version_info.major, sys.version_info.minor)
for path in sys.path:
if os.path.isdir(path) and path.rstrip('/').endswith(
(python_base, )):
if not path[-1] == os.sep:
path += os.sep
if module_path.startswith(path):
return True
return False
def _search_path(path): def _search_path(path):
mapping = dict() mapping = dict()
@ -424,4 +436,5 @@ def _search_path(path):
with open(top_level, 'r') as f: with open(top_level, 'r') as f:
for line in f: for line in f:
mapping[line.strip()] = ('-e', git_url) mapping[line.strip()] = ('-e', git_url)
return mapping return mapping