Add Task.ignore_requirements

This commit is contained in:
allegroai 2021-06-01 00:19:06 +03:00
parent d22cf7557d
commit e10940d038
2 changed files with 24 additions and 5 deletions

View File

@ -27,7 +27,7 @@ class ScriptInfoError(Exception):
class ScriptRequirements(object):
_max_requirements_size = 512 * 1024
_packages_remove_version = ('setuptools', )
_ignore_packages = ('pywin32',)
_ignore_packages = set()
def __init__(self, root_folder):
self._root_folder = root_folder
@ -82,7 +82,7 @@ class ScriptRequirements(object):
# if we have torch and it supports tensorboard, we should add that as well
# (because it will not be detected automatically)
if 'torch' in modules and 'tensorboard' not in modules:
if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
# noinspection PyBroadException
try:
# see if this version of torch support tensorboard
@ -149,13 +149,17 @@ class ScriptRequirements(object):
conda_requirements = ''
# add forced requirements:
forced_packages = {}
ignored_packages = ScriptRequirements._ignore_packages
# noinspection PyBroadException
try:
from ..task import Task
# noinspection PyProtectedMember
forced_packages = copy(Task._force_requirements)
# noinspection PyProtectedMember
ignored_packages = Task._ignore_requirements | ignored_packages
except Exception:
forced_packages = {}
pass
# python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
@ -171,7 +175,7 @@ class ScriptRequirements(object):
# requirement summary
requirements_txt += '\n'
for k, v in reqs.sorted_items():
if k.lower() in ScriptRequirements._ignore_packages:
if k in ignored_packages or k.lower() in ignored_packages:
continue
version = v.version
if k in forced_packages:

View File

@ -67,6 +67,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_default_configuration_section_name = 'General'
_legacy_parameters_section_name = 'Args'
_force_requirements = {}
_ignore_requirements = set()
_store_diff = config.get('development.store_uncommitted_code_diff', False)
_store_remote_diff = config.get('development.store_code_diff_from_remote', False)
@ -1679,7 +1680,21 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not running_remotely() and hasattr(cls, 'current_task') and cls.current_task():
get_logger('task').warning(
'Requirement ignored, Task.add_requirements() must be called before Task.init()')
cls._force_requirements[package_name] = package_version
cls._force_requirements[str(package_name)] = package_version
@classmethod
def ignore_requirements(cls, package_name):
# type: (str) -> None
"""
Ignore a specific package when auto generating the requirements list.
Example: Task.ignore_requirements('pywin32')
:param str package_name: The package name to remove/ignore from the "Installed Packages" section of the task.
"""
if not running_remotely() and hasattr(cls, 'current_task') and cls.current_task():
get_logger('task').warning(
'Requirement ignored, Task.ignore_requirements() must be called before Task.init()')
cls._ignore_requirements.add(str(package_name))
@classmethod
def force_requirements_env_freeze(cls, force=True):