Improve conda support

This commit is contained in:
allegroai 2020-01-21 16:32:57 +02:00
parent 9a3e130700
commit c5dd762d9b

View File

@ -5,10 +5,12 @@ from tempfile import mkstemp
import attr
import collections
import logging
import json
from furl import furl
from pathlib2 import Path
from threading import Thread, Event
from .util import get_command_output
from ....backend_api import Session
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
@ -63,7 +65,7 @@ class ScriptRequirements(object):
reqs, try_imports, guess = gr.extract_reqs()
return self.create_requirements_txt(reqs)
except Exception:
return ''
return '', ''
@staticmethod
def _patched_project_import_modules(project_path, ignores):
@ -115,6 +117,10 @@ class ScriptRequirements(object):
modules |= fmodules
try_imports |= try_ipts
return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
@staticmethod
def add_trains_used_packages(modules):
# hack: forcefully insert storage modules if we have them
# noinspection PyBroadException
try:
@ -146,11 +152,34 @@ class ScriptRequirements(object):
except Exception:
pass
return modules, try_imports, local_mods
return modules
@staticmethod
def create_requirements_txt(reqs):
# write requirements.txt
try:
conda_requirements = ''
conda_prefix = os.environ.get('CONDA_PREFIX')
if conda_prefix and not conda_prefix.endswith(os.path.sep):
conda_prefix += os.path.sep
if conda_prefix and sys.executable.startswith(conda_prefix):
conda_packages_json = get_command_output(['conda', 'list', '--json'])
conda_packages_json = json.loads(conda_packages_json)
reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
for r in conda_packages_json:
# check if this is a pypi package, if it is, leave it outside
if not r.get('channel') or r.get('channel') == 'pypi':
continue
# check if we have it in our required packages
name = r['name'].lower().replace('-', '_')
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
k, v = reqs_lower.get(name, (None, None))
if k:
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
except:
conda_requirements = ''
# python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
@ -178,7 +207,7 @@ class ScriptRequirements(object):
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
return requirements_txt
return requirements_txt, conda_requirements
class _JupyterObserver(object):
@ -206,6 +235,15 @@ class _JupyterObserver(object):
def signal_sync(cls, *_):
cls._sync_event.set()
@classmethod
def close(cls):
if not cls._thread:
return
cls._exit_event.set()
cls._sync_event.set()
cls._thread.join()
cls._thread = None
@classmethod
def _daemon(cls, jupyter_notebook_filename):
from trains import Task
@ -245,13 +283,11 @@ class _JupyterObserver(object):
last_update_ts = None
counter = 0
prev_script_hash = None
# main observer loop
while True:
# main observer loop, check if we need to exit
while not cls._exit_event.wait(timeout=0.):
# wait for timeout or sync event
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
# check if we need to exit
if cls._exit_event.wait(timeout=0.):
return
cls._sync_event.clear()
counter += 1
# noinspection PyBroadException
@ -283,23 +319,25 @@ class _JupyterObserver(object):
if prev_script_hash and prev_script_hash == current_script_hash:
continue
requirements_txt = ''
conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar)
# if backend supports requirements
if file_import_modules and Session.check_min_api_version('2.2'):
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
installed_pkgs = get_installed_pkgs_detail()
reqs = ReqsModules()
for name in fmodules:
if name in installed_pkgs:
pkg_name, version = installed_pkgs[name]
reqs.add(pkg_name, version, fmodules[name])
requirements_txt = ScriptRequirements.create_requirements_txt(reqs)
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
# update script
prev_script_hash = current_script_hash
data_script = task.data.script
data_script.diff = script_code
data_script.requirements = {'pip': requirements_txt}
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
task._update_script(script=data_script)
# update requirements
task._update_requirements(requirements=requirements_txt)
@ -491,8 +529,13 @@ class ScriptInfo(object):
_log("no info for {}", script_dir)
repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
if not plugin:
working_dir = '.'
entry_point = str(script_path.name)
else:
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
if check_uncommitted:
diff = cls._get_script_code(script_path.as_posix()) \
if not plugin or not repo_info.commit else repo_info.diff
@ -500,16 +543,18 @@ class ScriptInfo(object):
diff = ''
# if this is not jupyter, get the requirements.txt
requirements = ''
conda_requirements = ''
# create requirements if backend supports requirements
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if not jupyter_filepath and Session.check_min_api_version('2.2'):
script_requirements = ScriptRequirements(
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
if create_requirements:
requirements = script_requirements.get_requirements()
requirements, conda_requirements = script_requirements.get_requirements()
else:
script_requirements = None
script_info = dict(
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
branch=repo_info.branch,
@ -517,7 +562,7 @@ class ScriptInfo(object):
entry_point=entry_point,
working_dir=working_dir,
diff=diff,
requirements={'pip': requirements} if requirements else None,
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
)
messages = []
@ -545,6 +590,10 @@ class ScriptInfo(object):
log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None
@classmethod
def close(cls):
_JupyterObserver.close()
@attr.s
class ScriptInfoResult(object):