Add Task.add_requirements specify version <>=~ etc.

This commit is contained in:
allegroai 2021-02-27 23:51:40 +02:00
parent e24a421457
commit 7e6158dd9b
2 changed files with 22 additions and 14 deletions

View File

@ -172,24 +172,14 @@ class ScriptRequirements(object):
version = v.version
if k in forced_packages:
forced_version = forced_packages.pop(k, None)
if forced_version:
if forced_version is not None:
version = forced_version
# requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
if k == '-e' and version:
requirements_txt += '{0}\n'.format(version)
elif k.startswith('-e '):
requirements_txt += '{0} {1}\n'.format(k.replace('-e ', '', 1), version or '')
elif version:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', version)
else:
requirements_txt += '{0}\n'.format(k)
requirements_txt += ScriptRequirements._make_req_line(k, version)
# add forced requirements that we could not find installed on the system
for k in sorted(forced_packages.keys()):
if forced_packages[k]:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', forced_packages[k])
else:
requirements_txt += '{0}\n'.format(k)
requirements_txt += ScriptRequirements._make_req_line(k, forced_packages.get(k))
requirements_txt_packages_only = \
requirements_txt + '\n# Skipping detailed import analysis, it is too large\n'
@ -218,6 +208,21 @@ class ScriptRequirements(object):
else requirements_txt_packages_only,
conda_requirements)
@staticmethod
def _make_req_line(k, version):
requirements_txt = ''
if k == '-e' and version:
requirements_txt += '{0}\n'.format(version)
elif k.startswith('-e '):
requirements_txt += '{0} {1}\n'.format(k.replace('-e ', '', 1), version or '')
elif version and (str(version).strip() or ' ')[0] in '><~=':
requirements_txt += '{0} {1}\n'.format(k, version)
elif version and str(version).strip():
requirements_txt += '{0} {1} {2}\n'.format(k, '==', version)
else:
requirements_txt += '{0}\n'.format(k)
return requirements_txt
@staticmethod
def _remove_package_versions(installed_pkgs, package_names_to_remove_version):
installed_pkgs = {k: (v[0], None if str(k) in package_names_to_remove_version else v[1])

View File

@ -1612,9 +1612,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def add_requirements(cls, package_name, package_version=None):
# type: (str, Optional[str]) -> ()
"""
Force the adding of a package to the requirements list. If ``package_version`` is not specified, use the
Force the adding of a package to the requirements list. If ``package_version`` is None, use the
installed package version, if found.
Example: Task.add_requirements('tensorflow', '2.4.0')
Example: Task.add_requirements('tensorflow', '>=2.4')
Example: Task.add_requirements('tensorflow') -> use the installed tensorflow version
Example: Task.add_requirements('tensorflow', '') -> no version limit
:param str package_name: The package name to add to the "Installed Packages" section of the task.
:param package_version: The package version requirements. If ``None``, then use the installed version.