mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Improve conda support
This commit is contained in:
parent
9a3e130700
commit
c5dd762d9b
@ -5,10 +5,12 @@ from tempfile import mkstemp
|
|||||||
import attr
|
import attr
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
from furl import furl
|
from furl import furl
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
|
|
||||||
|
from .util import get_command_output
|
||||||
from ....backend_api import Session
|
from ....backend_api import Session
|
||||||
from ....debugging import get_logger
|
from ....debugging import get_logger
|
||||||
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
|
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
|
||||||
@ -63,7 +65,7 @@ class ScriptRequirements(object):
|
|||||||
reqs, try_imports, guess = gr.extract_reqs()
|
reqs, try_imports, guess = gr.extract_reqs()
|
||||||
return self.create_requirements_txt(reqs)
|
return self.create_requirements_txt(reqs)
|
||||||
except Exception:
|
except Exception:
|
||||||
return ''
|
return '', ''
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_project_import_modules(project_path, ignores):
|
def _patched_project_import_modules(project_path, ignores):
|
||||||
@ -115,6 +117,10 @@ class ScriptRequirements(object):
|
|||||||
modules |= fmodules
|
modules |= fmodules
|
||||||
try_imports |= try_ipts
|
try_imports |= try_ipts
|
||||||
|
|
||||||
|
return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_trains_used_packages(modules):
|
||||||
# hack: forcefully insert storage modules if we have them
|
# hack: forcefully insert storage modules if we have them
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -146,11 +152,34 @@ class ScriptRequirements(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return modules, try_imports, local_mods
|
return modules
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_requirements_txt(reqs):
|
def create_requirements_txt(reqs):
|
||||||
# write requirements.txt
|
# write requirements.txt
|
||||||
|
try:
|
||||||
|
conda_requirements = ''
|
||||||
|
conda_prefix = os.environ.get('CONDA_PREFIX')
|
||||||
|
if conda_prefix and not conda_prefix.endswith(os.path.sep):
|
||||||
|
conda_prefix += os.path.sep
|
||||||
|
if conda_prefix and sys.executable.startswith(conda_prefix):
|
||||||
|
conda_packages_json = get_command_output(['conda', 'list', '--json'])
|
||||||
|
conda_packages_json = json.loads(conda_packages_json)
|
||||||
|
reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
|
||||||
|
for r in conda_packages_json:
|
||||||
|
# check if this is a pypi package, if it is, leave it outside
|
||||||
|
if not r.get('channel') or r.get('channel') == 'pypi':
|
||||||
|
continue
|
||||||
|
# check if we have it in our required packages
|
||||||
|
name = r['name'].lower().replace('-', '_')
|
||||||
|
# hack support pytorch/torch different naming convention
|
||||||
|
if name == 'pytorch':
|
||||||
|
name = 'torch'
|
||||||
|
k, v = reqs_lower.get(name, (None, None))
|
||||||
|
if k:
|
||||||
|
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
|
||||||
|
except:
|
||||||
|
conda_requirements = ''
|
||||||
|
|
||||||
# python version header
|
# python version header
|
||||||
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
|
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
|
||||||
@ -178,7 +207,7 @@ class ScriptRequirements(object):
|
|||||||
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
|
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
|
||||||
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
|
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
|
||||||
|
|
||||||
return requirements_txt
|
return requirements_txt, conda_requirements
|
||||||
|
|
||||||
|
|
||||||
class _JupyterObserver(object):
|
class _JupyterObserver(object):
|
||||||
@ -206,6 +235,15 @@ class _JupyterObserver(object):
|
|||||||
def signal_sync(cls, *_):
|
def signal_sync(cls, *_):
|
||||||
cls._sync_event.set()
|
cls._sync_event.set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close(cls):
|
||||||
|
if not cls._thread:
|
||||||
|
return
|
||||||
|
cls._exit_event.set()
|
||||||
|
cls._sync_event.set()
|
||||||
|
cls._thread.join()
|
||||||
|
cls._thread = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _daemon(cls, jupyter_notebook_filename):
|
def _daemon(cls, jupyter_notebook_filename):
|
||||||
from trains import Task
|
from trains import Task
|
||||||
@ -245,13 +283,11 @@ class _JupyterObserver(object):
|
|||||||
last_update_ts = None
|
last_update_ts = None
|
||||||
counter = 0
|
counter = 0
|
||||||
prev_script_hash = None
|
prev_script_hash = None
|
||||||
# main observer loop
|
# main observer loop, check if we need to exit
|
||||||
while True:
|
while not cls._exit_event.wait(timeout=0.):
|
||||||
# wait for timeout or sync event
|
# wait for timeout or sync event
|
||||||
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
|
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
|
||||||
# check if we need to exit
|
|
||||||
if cls._exit_event.wait(timeout=0.):
|
|
||||||
return
|
|
||||||
cls._sync_event.clear()
|
cls._sync_event.clear()
|
||||||
counter += 1
|
counter += 1
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -283,23 +319,25 @@ class _JupyterObserver(object):
|
|||||||
if prev_script_hash and prev_script_hash == current_script_hash:
|
if prev_script_hash and prev_script_hash == current_script_hash:
|
||||||
continue
|
continue
|
||||||
requirements_txt = ''
|
requirements_txt = ''
|
||||||
|
conda_requirements = ''
|
||||||
# parse jupyter python script and prepare pip requirements (pigar)
|
# parse jupyter python script and prepare pip requirements (pigar)
|
||||||
# if backend supports requirements
|
# if backend supports requirements
|
||||||
if file_import_modules and Session.check_min_api_version('2.2'):
|
if file_import_modules and Session.check_min_api_version('2.2'):
|
||||||
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
|
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
|
||||||
|
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
|
||||||
installed_pkgs = get_installed_pkgs_detail()
|
installed_pkgs = get_installed_pkgs_detail()
|
||||||
reqs = ReqsModules()
|
reqs = ReqsModules()
|
||||||
for name in fmodules:
|
for name in fmodules:
|
||||||
if name in installed_pkgs:
|
if name in installed_pkgs:
|
||||||
pkg_name, version = installed_pkgs[name]
|
pkg_name, version = installed_pkgs[name]
|
||||||
reqs.add(pkg_name, version, fmodules[name])
|
reqs.add(pkg_name, version, fmodules[name])
|
||||||
requirements_txt = ScriptRequirements.create_requirements_txt(reqs)
|
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
|
||||||
|
|
||||||
# update script
|
# update script
|
||||||
prev_script_hash = current_script_hash
|
prev_script_hash = current_script_hash
|
||||||
data_script = task.data.script
|
data_script = task.data.script
|
||||||
data_script.diff = script_code
|
data_script.diff = script_code
|
||||||
data_script.requirements = {'pip': requirements_txt}
|
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
|
||||||
task._update_script(script=data_script)
|
task._update_script(script=data_script)
|
||||||
# update requirements
|
# update requirements
|
||||||
task._update_requirements(requirements=requirements_txt)
|
task._update_requirements(requirements=requirements_txt)
|
||||||
@ -491,8 +529,13 @@ class ScriptInfo(object):
|
|||||||
_log("no info for {}", script_dir)
|
_log("no info for {}", script_dir)
|
||||||
|
|
||||||
repo_root = repo_info.root or script_dir
|
repo_root = repo_info.root or script_dir
|
||||||
working_dir = cls._get_working_dir(repo_root)
|
if not plugin:
|
||||||
entry_point = cls._get_entry_point(repo_root, script_path)
|
working_dir = '.'
|
||||||
|
entry_point = str(script_path.name)
|
||||||
|
else:
|
||||||
|
working_dir = cls._get_working_dir(repo_root)
|
||||||
|
entry_point = cls._get_entry_point(repo_root, script_path)
|
||||||
|
|
||||||
if check_uncommitted:
|
if check_uncommitted:
|
||||||
diff = cls._get_script_code(script_path.as_posix()) \
|
diff = cls._get_script_code(script_path.as_posix()) \
|
||||||
if not plugin or not repo_info.commit else repo_info.diff
|
if not plugin or not repo_info.commit else repo_info.diff
|
||||||
@ -500,16 +543,18 @@ class ScriptInfo(object):
|
|||||||
diff = ''
|
diff = ''
|
||||||
# if this is not jupyter, get the requirements.txt
|
# if this is not jupyter, get the requirements.txt
|
||||||
requirements = ''
|
requirements = ''
|
||||||
|
conda_requirements = ''
|
||||||
# create requirements if backend supports requirements
|
# create requirements if backend supports requirements
|
||||||
# if jupyter is present, requirements will be created in the background, when saving a snapshot
|
# if jupyter is present, requirements will be created in the background, when saving a snapshot
|
||||||
if not jupyter_filepath and Session.check_min_api_version('2.2'):
|
if not jupyter_filepath and Session.check_min_api_version('2.2'):
|
||||||
script_requirements = ScriptRequirements(
|
script_requirements = ScriptRequirements(
|
||||||
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
|
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
|
||||||
if create_requirements:
|
if create_requirements:
|
||||||
requirements = script_requirements.get_requirements()
|
requirements, conda_requirements = script_requirements.get_requirements()
|
||||||
else:
|
else:
|
||||||
script_requirements = None
|
script_requirements = None
|
||||||
|
|
||||||
|
|
||||||
script_info = dict(
|
script_info = dict(
|
||||||
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
|
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
|
||||||
branch=repo_info.branch,
|
branch=repo_info.branch,
|
||||||
@ -517,7 +562,7 @@ class ScriptInfo(object):
|
|||||||
entry_point=entry_point,
|
entry_point=entry_point,
|
||||||
working_dir=working_dir,
|
working_dir=working_dir,
|
||||||
diff=diff,
|
diff=diff,
|
||||||
requirements={'pip': requirements} if requirements else None,
|
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
@ -545,6 +590,10 @@ class ScriptInfo(object):
|
|||||||
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
||||||
return ScriptInfoResult(), None
|
return ScriptInfoResult(), None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close(cls):
|
||||||
|
_JupyterObserver.close()
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class ScriptInfoResult(object):
|
class ScriptInfoResult(object):
|
||||||
|
Loading…
Reference in New Issue
Block a user