Add Task.add_requirements() to force requirement package into "installed packages"

This commit is contained in:
allegroai 2020-05-22 10:35:27 +03:00
parent 072abfd6fd
commit 0e2265a9ca
2 changed files with 33 additions and 2 deletions

View File

@ -1,5 +1,6 @@
import os import os
import sys import sys
from copy import copy
from tempfile import mkstemp from tempfile import mkstemp
import attr import attr
@ -80,6 +81,15 @@ class ScriptRequirements(object):
except Exception: except Exception:
pass pass
# add forced requirements:
# noinspection PyBroadException
try:
from ..task import Task
for package, version in Task._force_requirements.items():
modules.add(package, 'trains', 0)
except Exception:
pass
return modules return modules
@staticmethod @staticmethod
@ -109,6 +119,14 @@ class ScriptRequirements(object):
except: except:
conda_requirements = '' conda_requirements = ''
# add forced requirements:
# noinspection PyBroadException
try:
from ..task import Task
forced_packages = copy(Task._force_requirements)
except Exception:
forced_packages = {}
# python version header # python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n' requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
@ -120,11 +138,23 @@ class ScriptRequirements(object):
# requirement summary # requirement summary
requirements_txt += '\n' requirements_txt += '\n'
for k, v in reqs.sorted_items(): for k, v in reqs.sorted_items():
version = v.version
if k in forced_packages:
forced_version = forced_packages.pop(k, None)
if forced_version:
version = forced_version
# requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()]) # requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
if k == '-e': if k == '-e':
requirements_txt += '{0} {1}\n'.format(k, v.version) requirements_txt += '{0} {1}\n'.format(k, version)
elif v: elif v:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version) requirements_txt += '{0} {1} {2}\n'.format(k, '==', version)
else:
requirements_txt += '{0}\n'.format(k)
# add forced requirements that we could not find installed on the system
for k in sorted(forced_packages.keys()):
if forced_packages[k]:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', forced_packages[k])
else: else:
requirements_txt += '{0}\n'.format(k) requirements_txt += '{0}\n'.format(k)

View File

@ -48,6 +48,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__' _anonymous_dataview_id = '__anonymous__'
_development_tag = 'development' _development_tag = 'development'
_force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False) _store_diff = config.get('development.store_uncommitted_code_diff', False)