mirror of
https://github.com/clearml/clearml
synced 2025-03-03 02:32:11 +00:00
Improve conda support
This commit is contained in:
parent
9a3e130700
commit
c5dd762d9b
@ -5,10 +5,12 @@ from tempfile import mkstemp
|
||||
import attr
|
||||
import collections
|
||||
import logging
|
||||
import json
|
||||
from furl import furl
|
||||
from pathlib2 import Path
|
||||
from threading import Thread, Event
|
||||
|
||||
from .util import get_command_output
|
||||
from ....backend_api import Session
|
||||
from ....debugging import get_logger
|
||||
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
|
||||
@ -63,7 +65,7 @@ class ScriptRequirements(object):
|
||||
reqs, try_imports, guess = gr.extract_reqs()
|
||||
return self.create_requirements_txt(reqs)
|
||||
except Exception:
|
||||
return ''
|
||||
return '', ''
|
||||
|
||||
@staticmethod
|
||||
def _patched_project_import_modules(project_path, ignores):
|
||||
@ -115,6 +117,10 @@ class ScriptRequirements(object):
|
||||
modules |= fmodules
|
||||
try_imports |= try_ipts
|
||||
|
||||
return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
|
||||
|
||||
@staticmethod
|
||||
def add_trains_used_packages(modules):
|
||||
# hack: forcefully insert storage modules if we have them
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -146,11 +152,34 @@ class ScriptRequirements(object):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return modules, try_imports, local_mods
|
||||
return modules
|
||||
|
||||
@staticmethod
|
||||
def create_requirements_txt(reqs):
|
||||
# write requirements.txt
|
||||
try:
|
||||
conda_requirements = ''
|
||||
conda_prefix = os.environ.get('CONDA_PREFIX')
|
||||
if conda_prefix and not conda_prefix.endswith(os.path.sep):
|
||||
conda_prefix += os.path.sep
|
||||
if conda_prefix and sys.executable.startswith(conda_prefix):
|
||||
conda_packages_json = get_command_output(['conda', 'list', '--json'])
|
||||
conda_packages_json = json.loads(conda_packages_json)
|
||||
reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
|
||||
for r in conda_packages_json:
|
||||
# check if this is a pypi package, if it is, leave it outside
|
||||
if not r.get('channel') or r.get('channel') == 'pypi':
|
||||
continue
|
||||
# check if we have it in our required packages
|
||||
name = r['name'].lower().replace('-', '_')
|
||||
# hack support pytorch/torch different naming convention
|
||||
if name == 'pytorch':
|
||||
name = 'torch'
|
||||
k, v = reqs_lower.get(name, (None, None))
|
||||
if k:
|
||||
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
|
||||
except:
|
||||
conda_requirements = ''
|
||||
|
||||
# python version header
|
||||
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
|
||||
@ -178,7 +207,7 @@ class ScriptRequirements(object):
|
||||
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
|
||||
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
|
||||
|
||||
return requirements_txt
|
||||
return requirements_txt, conda_requirements
|
||||
|
||||
|
||||
class _JupyterObserver(object):
|
||||
@ -206,6 +235,15 @@ class _JupyterObserver(object):
|
||||
def signal_sync(cls, *_):
|
||||
cls._sync_event.set()
|
||||
|
||||
@classmethod
def close(cls):
    """
    Stop the background observer thread, if one was started.

    Raises the exit event, then sets the sync event so the observer loop
    (which blocks on the sync event between samples) wakes up immediately
    and sees the exit request. Waits for the thread to finish and clears
    the stored thread reference. Does nothing when no thread is set.
    """
    observer_thread = cls._thread
    if observer_thread:
        # order matters: flag the exit first, then wake the waiting loop
        cls._exit_event.set()
        cls._sync_event.set()
        observer_thread.join()
        cls._thread = None
|
||||
|
||||
@classmethod
|
||||
def _daemon(cls, jupyter_notebook_filename):
|
||||
from trains import Task
|
||||
@ -245,13 +283,11 @@ class _JupyterObserver(object):
|
||||
last_update_ts = None
|
||||
counter = 0
|
||||
prev_script_hash = None
|
||||
# main observer loop
|
||||
while True:
|
||||
# main observer loop, check if we need to exit
|
||||
while not cls._exit_event.wait(timeout=0.):
|
||||
# wait for timeout or sync event
|
||||
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
|
||||
# check if we need to exit
|
||||
if cls._exit_event.wait(timeout=0.):
|
||||
return
|
||||
|
||||
cls._sync_event.clear()
|
||||
counter += 1
|
||||
# noinspection PyBroadException
|
||||
@ -283,23 +319,25 @@ class _JupyterObserver(object):
|
||||
if prev_script_hash and prev_script_hash == current_script_hash:
|
||||
continue
|
||||
requirements_txt = ''
|
||||
conda_requirements = ''
|
||||
# parse jupyter python script and prepare pip requirements (pigar)
|
||||
# if backend supports requirements
|
||||
if file_import_modules and Session.check_min_api_version('2.2'):
|
||||
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
|
||||
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
|
||||
installed_pkgs = get_installed_pkgs_detail()
|
||||
reqs = ReqsModules()
|
||||
for name in fmodules:
|
||||
if name in installed_pkgs:
|
||||
pkg_name, version = installed_pkgs[name]
|
||||
reqs.add(pkg_name, version, fmodules[name])
|
||||
requirements_txt = ScriptRequirements.create_requirements_txt(reqs)
|
||||
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
|
||||
|
||||
# update script
|
||||
prev_script_hash = current_script_hash
|
||||
data_script = task.data.script
|
||||
data_script.diff = script_code
|
||||
data_script.requirements = {'pip': requirements_txt}
|
||||
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
|
||||
task._update_script(script=data_script)
|
||||
# update requirements
|
||||
task._update_requirements(requirements=requirements_txt)
|
||||
@ -491,8 +529,13 @@ class ScriptInfo(object):
|
||||
_log("no info for {}", script_dir)
|
||||
|
||||
repo_root = repo_info.root or script_dir
|
||||
working_dir = cls._get_working_dir(repo_root)
|
||||
entry_point = cls._get_entry_point(repo_root, script_path)
|
||||
if not plugin:
|
||||
working_dir = '.'
|
||||
entry_point = str(script_path.name)
|
||||
else:
|
||||
working_dir = cls._get_working_dir(repo_root)
|
||||
entry_point = cls._get_entry_point(repo_root, script_path)
|
||||
|
||||
if check_uncommitted:
|
||||
diff = cls._get_script_code(script_path.as_posix()) \
|
||||
if not plugin or not repo_info.commit else repo_info.diff
|
||||
@ -500,16 +543,18 @@ class ScriptInfo(object):
|
||||
diff = ''
|
||||
# if this is not jupyter, get the requirements.txt
|
||||
requirements = ''
|
||||
conda_requirements = ''
|
||||
# create requirements if backend supports requirements
|
||||
# if jupyter is present, requirements will be created in the background, when saving a snapshot
|
||||
if not jupyter_filepath and Session.check_min_api_version('2.2'):
|
||||
script_requirements = ScriptRequirements(
|
||||
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
|
||||
if create_requirements:
|
||||
requirements = script_requirements.get_requirements()
|
||||
requirements, conda_requirements = script_requirements.get_requirements()
|
||||
else:
|
||||
script_requirements = None
|
||||
|
||||
|
||||
script_info = dict(
|
||||
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
|
||||
branch=repo_info.branch,
|
||||
@ -517,7 +562,7 @@ class ScriptInfo(object):
|
||||
entry_point=entry_point,
|
||||
working_dir=working_dir,
|
||||
diff=diff,
|
||||
requirements={'pip': requirements} if requirements else None,
|
||||
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
|
||||
)
|
||||
|
||||
messages = []
|
||||
@ -545,6 +590,10 @@ class ScriptInfo(object):
|
||||
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
||||
return ScriptInfoResult(), None
|
||||
|
||||
@classmethod
def close(cls):
    """
    Shut down the script-info background machinery.

    Delegates to _JupyterObserver.close().
    """
    _JupyterObserver.close()
|
||||
|
||||
|
||||
@attr.s
|
||||
class ScriptInfoResult(object):
|
||||
|
Loading…
Reference in New Issue
Block a user