mirror of
https://github.com/clearml/clearml
synced 2025-03-02 10:12:27 +00:00
If sys.argv doesn't point into a git repo, take file calling Task.init(). Support running code from a module (i.e. -m module)
This commit is contained in:
parent
183ad248cf
commit
b865fc0072
trains
@ -83,6 +83,10 @@ class ScriptRequirements(object):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# remove setuptools, we should not specify this module version. It is installed by default
|
||||
if 'setuptools' in modules:
|
||||
modules.pop('setuptools', {})
|
||||
|
||||
# add forced requirements:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -494,18 +498,18 @@ class ScriptInfo(object):
|
||||
return ''
|
||||
|
||||
@classmethod
|
||||
def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
|
||||
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
|
||||
jupyter_filepath = cls._get_jupyter_notebook_filename()
|
||||
if jupyter_filepath:
|
||||
script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
|
||||
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
|
||||
else:
|
||||
script_path = Path(os.path.normpath(filepath)).absolute()
|
||||
if not script_path.is_file():
|
||||
scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
|
||||
if all(not f.is_file() for f in scripts_path):
|
||||
raise ScriptInfoError(
|
||||
"Script file [{}] could not be found".format(filepath)
|
||||
"Script file {} could not be found".format(scripts_path)
|
||||
)
|
||||
|
||||
script_dir = script_path.parent
|
||||
scripts_dir = [f.parent for f in scripts_path]
|
||||
|
||||
def _log(msg, *args, **kwargs):
|
||||
if not log:
|
||||
@ -516,18 +520,25 @@ class ScriptInfo(object):
|
||||
)
|
||||
)
|
||||
|
||||
plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
|
||||
plugin = next((p for p in cls.plugins if any(p.exists(d) for d in scripts_dir)), None)
|
||||
repo_info = DetectionResult()
|
||||
script_dir = scripts_dir[0]
|
||||
script_path = scripts_path[0]
|
||||
if not plugin:
|
||||
log.info("No repository found, storing script code instead")
|
||||
else:
|
||||
try:
|
||||
repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
|
||||
for i, d in enumerate(scripts_dir):
|
||||
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
|
||||
if not repo_info.is_empty():
|
||||
script_dir = d
|
||||
script_path = scripts_path[i]
|
||||
break
|
||||
except Exception as ex:
|
||||
_log("no info for {} ({})", script_dir, ex)
|
||||
_log("no info for {} ({})", scripts_dir, ex)
|
||||
else:
|
||||
if repo_info.is_empty():
|
||||
_log("no info for {}", script_dir)
|
||||
_log("no info for {}", scripts_dir)
|
||||
|
||||
repo_root = repo_info.root or script_dir
|
||||
if not plugin:
|
||||
@ -564,6 +575,8 @@ class ScriptInfo(object):
|
||||
diff=diff,
|
||||
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
|
||||
binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
|
||||
repo_root=repo_root,
|
||||
jupyter_filepath=jupyter_filepath,
|
||||
)
|
||||
|
||||
messages = []
|
||||
@ -581,16 +594,44 @@ class ScriptInfo(object):
|
||||
script_requirements)
|
||||
|
||||
@classmethod
|
||||
def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
|
||||
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
|
||||
try:
|
||||
if not filepaths:
|
||||
filepaths = [sys.argv[0], ]
|
||||
return cls._get_script_info(
|
||||
filepath=filepath, check_uncommitted=check_uncommitted,
|
||||
filepaths=filepaths, check_uncommitted=check_uncommitted,
|
||||
create_requirements=create_requirements, log=log)
|
||||
except Exception as ex:
|
||||
if log:
|
||||
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
||||
return ScriptInfoResult(), None
|
||||
|
||||
@classmethod
|
||||
def detect_running_module(cls, script_dict):
    """
    Rewrite the script entry point when the process was launched as a module (``python -m package``).

    When ``__main__`` reports a non-empty ``__package__``, ``script_dict['entry_point']`` is replaced
    with ``-m <package>`` followed by the command line arguments (``sys.argv[1:]``), separated by
    single spaces. Arguments that are existing paths located under ``script_dict['repo_root']`` are
    rewritten relative to the working directory inside the repository
    (``repo_root`` joined with ``script_dict['working_dir']``).

    Detection is best effort: on any failure ``script_dict`` is returned as it was received.

    :param script_dict: dict-like script description. The keys read are ``jupyter_filepath``,
        ``repo_root`` and ``working_dir``; ``entry_point`` is written.
    :return: the same ``script_dict`` object (updated in place when a module run is detected)
    """
    # noinspection PyBroadException
    try:
        # If this is jupyter, do not try to detect the running module, we know what we have.
        if script_dict.get('jupyter_filepath'):
            return script_dict

        main_module = sys.modules.get('__main__')
        package = vars(main_module).get('__package__') if main_module is not None else None
        if not package:
            # not started with "-m", keep the detected entry point
            return script_dict

        repo_root = script_dict.get('repo_root')
        git_root = os.path.abspath(repo_root) if repo_root else None
        working_dir = os.path.join(git_root, script_dict.get('working_dir') or '') if git_root else None

        args = []
        for arg in sys.argv[1:]:
            if git_root and os.path.exists(arg):
                arg_abs = os.path.abspath(arg)
                try:
                    # check if the path is part of the project
                    inside_repo = os.path.commonpath([arg_abs, git_root]) == git_root
                except ValueError:
                    # e.g. paths on different drives: not part of the repository, keep argument as is
                    inside_repo = False
                if inside_repo:
                    # adjust path relative to working dir inside git repo
                    arg = os.path.relpath(arg_abs, working_dir)
            args.append(arg)

        # update the script entry point to match the real argv and module call
        script_dict['entry_point'] = ' '.join(['-m', package] + args)
    except Exception:
        # best effort only: never fail task creation because module detection failed
        pass
    return script_dict
|
||||
|
||||
@classmethod
|
||||
def close(cls):
    """Release the Jupyter observer by delegating to ``_JupyterObserver.close()``."""
    _JupyterObserver.close()
|
||||
|
@ -2,6 +2,7 @@
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
from enum import Enum
|
||||
from tempfile import gettempdir
|
||||
@ -247,20 +248,28 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
check_package_update_thread.start()
|
||||
# do not request requirements, because it might be a long process, and we first want to update the git repo
|
||||
result, script_requirements = ScriptInfo.get(
|
||||
filepaths=[self._calling_filename, sys.argv[0], ],
|
||||
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
|
||||
)
|
||||
for msg in result.warning_messages:
|
||||
self.get_logger().report_text(msg)
|
||||
|
||||
# store original entry point
|
||||
entry_point = result.script.get('entry_point') if result.script else None
|
||||
|
||||
# check if we are running inside a module, then we should set our entrypoint
|
||||
# to the module call including all argv's
|
||||
result.script = ScriptInfo.detect_running_module(result.script)
|
||||
|
||||
self.data.script = result.script
|
||||
# Since we might run asynchronously, don't use self.data (lest someone else
|
||||
# Since we might run asynchronously, don't use self.data (lest someone else
|
||||
# overwrite it before we have a chance to call edit)
|
||||
self._edit(script=result.script)
|
||||
self.reload()
|
||||
# if jupyter is present, requirements will be created in the background, when saving a snapshot
|
||||
if result.script and script_requirements:
|
||||
entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
|
||||
os.path.join(result.script['working_dir'], result.script['entry_point'])
|
||||
os.path.join(result.script['working_dir'], entry_point)
|
||||
requirements, conda_requirements = script_requirements.get_requirements(
|
||||
entry_point_filename=entry_point_filename)
|
||||
|
||||
|
@ -35,6 +35,10 @@ try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
np = None
|
||||
try:
|
||||
from pathlib import Path as pathlib_Path
|
||||
except ImportError:
|
||||
pathlib_Path = None
|
||||
|
||||
|
||||
class Artifact(object):
|
||||
@ -353,7 +357,8 @@ class Artifacts(object):
|
||||
uri = artifact_object
|
||||
artifact_type = 'custom'
|
||||
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
|
||||
elif isinstance(artifact_object, six.string_types + (Path,)):
|
||||
elif isinstance(
|
||||
artifact_object, six.string_types + (Path, pathlib_Path,) if pathlib_Path is not None else (Path,)):
|
||||
# check if single file
|
||||
artifact_object = Path(artifact_object)
|
||||
|
||||
|
@ -147,6 +147,7 @@ class Task(_Task):
|
||||
self._detect_repo_async_thread = None
|
||||
self._resource_monitor = None
|
||||
self._artifacts_manager = Artifacts(self)
|
||||
self._calling_filename = None
|
||||
# register atexit, so that we mark the task as stopped
|
||||
self._at_exit_called = False
|
||||
|
||||
@ -1542,6 +1543,18 @@ class Task(_Task):
|
||||
|
||||
# update current repository and put warning into logs
|
||||
if detect_repo:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import traceback
|
||||
stack = traceback.extract_stack(limit=10)
|
||||
# NOTICE WE ARE ALWAYS 3 down from caller in stack!
|
||||
for i in range(len(stack)-1, 0, -1):
|
||||
# look for the Task.init call, then the one above it is the caller module
|
||||
if stack[i].name == 'init':
|
||||
task._calling_filename = os.path.abspath(stack[i-1].filename)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if in_dev_mode and cls.__detect_repo_async:
|
||||
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
|
||||
task._detect_repo_async_thread.daemon = True
|
||||
|
@ -5,7 +5,7 @@ from __future__ import print_function, division, absolute_import
|
||||
import os
|
||||
import codecs
|
||||
|
||||
from .reqs import project_import_modules, is_std_or_local_lib
|
||||
from .reqs import project_import_modules, is_std_or_local_lib, is_base_module
|
||||
from .utils import lines_diff
|
||||
from .log import logger
|
||||
from .modules import ReqsModules
|
||||
@ -82,16 +82,20 @@ class GenerateReqs(object):
|
||||
guess.add(name, 0, modules[name])
|
||||
|
||||
# add local modules, so we know what is used but not installed.
|
||||
project_path = os.path.realpath(self._project_path)
|
||||
for name in self._local_mods:
|
||||
if name in modules:
|
||||
if name in self._force_modules_reqs:
|
||||
reqs.add(name, self._force_modules_reqs[name], modules[name])
|
||||
continue
|
||||
|
||||
# if this is a base module, we have it in installed modules but package name is None
|
||||
mod_path = os.path.realpath(self._local_mods[name])
|
||||
if is_base_module(mod_path):
|
||||
continue
|
||||
|
||||
# if this is a folder of our project, we can safely ignore it
|
||||
if os.path.commonpath([os.path.realpath(self._project_path)]) == \
|
||||
os.path.commonpath([os.path.realpath(self._project_path),
|
||||
os.path.realpath(self._local_mods[name])]):
|
||||
if os.path.commonpath([project_path]) == os.path.commonpath([project_path, mod_path]):
|
||||
continue
|
||||
|
||||
relpath = os.path.relpath(self._local_mods[name], self._project_path)
|
||||
|
@ -333,12 +333,24 @@ def get_installed_pkgs_detail():
|
||||
mapping = new_mapping
|
||||
|
||||
# HACK: prefer tensorflow_gpu over tensorflow
|
||||
if 'tensorflow_gpu' in new_mapping:
|
||||
new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
|
||||
if 'tensorflow_gpu' in mapping:
|
||||
mapping['tensorflow'] = mapping['tensorflow_gpu']
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def is_base_module(module_path):
    """
    Check whether ``module_path`` is located under a ``pythonX.Y`` directory listed in ``sys.path``.

    Only ``sys.path`` entries that are existing directories and whose name ends with
    ``<sep>python<major>.<minor>`` of the running interpreter are considered. Anything nested below
    such a directory matches, because the comparison is a plain path-prefix test.

    Both the entry as listed and its symlink-resolved form are tested, so a module path that was
    already resolved with ``os.path.realpath`` still matches a symlinked ``sys.path`` entry.

    :param module_path: file system path of the module (string)
    :return: True if the module lives under such a directory, False otherwise
    """
    if not module_path:
        return False

    python_base = '{}python{}.{}'.format(os.sep, sys.version_info.major, sys.version_info.minor)
    # strip every kind of trailing separator the platform knows, not just '/'
    separators = os.sep + (os.altsep or '')
    for path in sys.path:
        # sys.path may contain non-string entries (None, path objects), skip them instead of failing
        if not isinstance(path, (str, type(u''))) or not os.path.isdir(path):
            continue
        if not path.rstrip(separators).endswith(python_base):
            continue
        for candidate in (path, os.path.realpath(path)):
            # normalize to exactly one trailing separator so 'python3.8' never matches 'python3.80'
            prefix = candidate.rstrip(separators) + os.sep
            if module_path.startswith(prefix):
                return True
    return False
|
||||
|
||||
|
||||
def _search_path(path):
|
||||
mapping = dict()
|
||||
|
||||
@ -424,4 +436,5 @@ def _search_path(path):
|
||||
with open(top_level, 'r') as f:
|
||||
for line in f:
|
||||
mapping[line.strip()] = ('-e', git_url)
|
||||
|
||||
return mapping
|
||||
|
Loading…
Reference in New Issue
Block a user