Improve conda support

This commit is contained in:
allegroai 2020-01-21 16:32:57 +02:00
parent 9a3e130700
commit c5dd762d9b

View File

@ -5,10 +5,12 @@ from tempfile import mkstemp
import attr import attr
import collections import collections
import logging import logging
import json
from furl import furl from furl import furl
from pathlib2 import Path from pathlib2 import Path
from threading import Thread, Event from threading import Thread, Event
from .util import get_command_output
from ....backend_api import Session from ....backend_api import Session
from ....debugging import get_logger from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
@ -63,7 +65,7 @@ class ScriptRequirements(object):
reqs, try_imports, guess = gr.extract_reqs() reqs, try_imports, guess = gr.extract_reqs()
return self.create_requirements_txt(reqs) return self.create_requirements_txt(reqs)
except Exception: except Exception:
return '' return '', ''
@staticmethod @staticmethod
def _patched_project_import_modules(project_path, ignores): def _patched_project_import_modules(project_path, ignores):
@ -115,6 +117,10 @@ class ScriptRequirements(object):
modules |= fmodules modules |= fmodules
try_imports |= try_ipts try_imports |= try_ipts
return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
@staticmethod
def add_trains_used_packages(modules):
# hack: forcefully insert storage modules if we have them # hack: forcefully insert storage modules if we have them
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -146,11 +152,34 @@ class ScriptRequirements(object):
except Exception: except Exception:
pass pass
return modules, try_imports, local_mods return modules
@staticmethod @staticmethod
def create_requirements_txt(reqs): def create_requirements_txt(reqs):
# write requirements.txt # write requirements.txt
try:
conda_requirements = ''
conda_prefix = os.environ.get('CONDA_PREFIX')
if conda_prefix and not conda_prefix.endswith(os.path.sep):
conda_prefix += os.path.sep
if conda_prefix and sys.executable.startswith(conda_prefix):
conda_packages_json = get_command_output(['conda', 'list', '--json'])
conda_packages_json = json.loads(conda_packages_json)
reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
for r in conda_packages_json:
# check if this is a pypi package, if it is, leave it outside
if not r.get('channel') or r.get('channel') == 'pypi':
continue
# check if we have it in our required packages
name = r['name'].lower().replace('-', '_')
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
k, v = reqs_lower.get(name, (None, None))
if k:
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
except:
conda_requirements = ''
# python version header # python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n' requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
@ -178,7 +207,7 @@ class ScriptRequirements(object):
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k) requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()]) requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
return requirements_txt return requirements_txt, conda_requirements
class _JupyterObserver(object): class _JupyterObserver(object):
@ -206,6 +235,15 @@ class _JupyterObserver(object):
def signal_sync(cls, *_): def signal_sync(cls, *_):
cls._sync_event.set() cls._sync_event.set()
@classmethod
def close(cls):
    """Stop the background observer thread, if one was started.

    Does nothing when ``cls._thread`` is not set. Otherwise the exit event
    is raised, the sync event is raised, and the call blocks until the
    thread has finished; the thread reference is then cleared.
    """
    if not cls._thread:
        return
    # raise the exit flag first, so it is already set when the thread wakes up
    cls._exit_event.set()
    # wake anything blocked on the sync event instead of waiting for its timeout
    cls._sync_event.set()
    # block until the observer thread has finished
    cls._thread.join()
    cls._thread = None
@classmethod @classmethod
def _daemon(cls, jupyter_notebook_filename): def _daemon(cls, jupyter_notebook_filename):
from trains import Task from trains import Task
@ -245,13 +283,11 @@ class _JupyterObserver(object):
last_update_ts = None last_update_ts = None
counter = 0 counter = 0
prev_script_hash = None prev_script_hash = None
# main observer loop # main observer loop, check if we need to exit
while True: while not cls._exit_event.wait(timeout=0.):
# wait for timeout or sync event # wait for timeout or sync event
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency) cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
# check if we need to exit
if cls._exit_event.wait(timeout=0.):
return
cls._sync_event.clear() cls._sync_event.clear()
counter += 1 counter += 1
# noinspection PyBroadException # noinspection PyBroadException
@ -283,23 +319,25 @@ class _JupyterObserver(object):
if prev_script_hash and prev_script_hash == current_script_hash: if prev_script_hash and prev_script_hash == current_script_hash:
continue continue
requirements_txt = '' requirements_txt = ''
conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar) # parse jupyter python script and prepare pip requirements (pigar)
# if backend supports requirements # if backend supports requirements
if file_import_modules and Session.check_min_api_version('2.2'): if file_import_modules and Session.check_min_api_version('2.2'):
fmodules, _ = file_import_modules(notebook.parts[-1], script_code) fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
installed_pkgs = get_installed_pkgs_detail() installed_pkgs = get_installed_pkgs_detail()
reqs = ReqsModules() reqs = ReqsModules()
for name in fmodules: for name in fmodules:
if name in installed_pkgs: if name in installed_pkgs:
pkg_name, version = installed_pkgs[name] pkg_name, version = installed_pkgs[name]
reqs.add(pkg_name, version, fmodules[name]) reqs.add(pkg_name, version, fmodules[name])
requirements_txt = ScriptRequirements.create_requirements_txt(reqs) requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
# update script # update script
prev_script_hash = current_script_hash prev_script_hash = current_script_hash
data_script = task.data.script data_script = task.data.script
data_script.diff = script_code data_script.diff = script_code
data_script.requirements = {'pip': requirements_txt} data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
task._update_script(script=data_script) task._update_script(script=data_script)
# update requirements # update requirements
task._update_requirements(requirements=requirements_txt) task._update_requirements(requirements=requirements_txt)
@ -491,8 +529,13 @@ class ScriptInfo(object):
_log("no info for {}", script_dir) _log("no info for {}", script_dir)
repo_root = repo_info.root or script_dir repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root) if not plugin:
entry_point = cls._get_entry_point(repo_root, script_path) working_dir = '.'
entry_point = str(script_path.name)
else:
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
if check_uncommitted: if check_uncommitted:
diff = cls._get_script_code(script_path.as_posix()) \ diff = cls._get_script_code(script_path.as_posix()) \
if not plugin or not repo_info.commit else repo_info.diff if not plugin or not repo_info.commit else repo_info.diff
@ -500,16 +543,18 @@ class ScriptInfo(object):
diff = '' diff = ''
# if this is not jupyter, get the requirements.txt # if this is not jupyter, get the requirements.txt
requirements = '' requirements = ''
conda_requirements = ''
# create requirements if backend supports requirements # create requirements if backend supports requirements
# if jupyter is present, requirements will be created in the background, when saving a snapshot # if jupyter is present, requirements will be created in the background, when saving a snapshot
if not jupyter_filepath and Session.check_min_api_version('2.2'): if not jupyter_filepath and Session.check_min_api_version('2.2'):
script_requirements = ScriptRequirements( script_requirements = ScriptRequirements(
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix()) Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
if create_requirements: if create_requirements:
requirements = script_requirements.get_requirements() requirements, conda_requirements = script_requirements.get_requirements()
else: else:
script_requirements = None script_requirements = None
script_info = dict( script_info = dict(
repository=furl(repo_info.url).remove(username=True, password=True).tostr(), repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
branch=repo_info.branch, branch=repo_info.branch,
@ -517,7 +562,7 @@ class ScriptInfo(object):
entry_point=entry_point, entry_point=entry_point,
working_dir=working_dir, working_dir=working_dir,
diff=diff, diff=diff,
requirements={'pip': requirements} if requirements else None, requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
) )
messages = [] messages = []
@ -545,6 +590,10 @@ class ScriptInfo(object):
log.warning("Failed auto-detecting task repository: {}".format(ex)) log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None return ScriptInfoResult(), None
@classmethod
def close(cls):
    """Shut down the Jupyter observer by delegating to ``_JupyterObserver.close()``."""
    _JupyterObserver.close()
@attr.s @attr.s
class ScriptInfoResult(object): class ScriptInfoResult(object):