Improve module requirements detection

This commit is contained in:
allegroai 2020-04-26 23:10:45 +03:00
parent 9726f782f2
commit dbb3346332
2 changed files with 37 additions and 9 deletions

View File

@ -151,6 +151,12 @@ sdk {
# Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
default_output_uri: ""
# Default auto generated requirements optimize for smaller requirements
# If True, analyze the entire repository regardless of the entry point.
# If False, first analyze the entry point script, if it does not contain other to local files,
# do not analyze the entire repository.
force_analyze_entire_repo: false
# Development mode worker
worker {
# Status report period in seconds

View File

@ -12,6 +12,7 @@ from .modules import ReqsModules
class GenerateReqs(object):
_force_modules_reqs = dict()
def __init__(self, save_path, project_path, ignores,
installed_pkgs, comparison_operator='=='):
@ -30,18 +31,34 @@ class GenerateReqs(object):
guess = ReqsModules()
local = ReqsModules()
# make the entry point absolute (relative to the root path)
num_local_mod = 0
if self.__module__:
# create a copy, do not change the class set
our_module = self.__module__.split('.')[0]
if our_module and our_module not in self._force_modules_reqs:
from ...version import __version__
self._force_modules_reqs[our_module] = __version__
# make the entry point absolute (relative o the root path)
if entry_point_filename and not os.path.isabs(entry_point_filename):
entry_point_filename = os.path.join(self._project_path, entry_point_filename) \
if os.path.isdir(self._project_path) else None
# check if the entry point script is self contained, i.e. does not use the rest of the project
if entry_point_filename and os.path.isfile(entry_point_filename):
if entry_point_filename and os.path.isfile(entry_point_filename) and not self._local_mods:
modules, try_imports, local_mods = project_import_modules(entry_point_filename, self._ignores)
if not local_mods:
# update the self._local_mods
self._filter_modules(modules, local_mods)
# check how many local modules we have, excluding ourselves
num_local_mod = len(set(self._local_mods.keys()) - set(self._force_modules_reqs.keys()))
# if we have any module/package we cannot find, take no chances and scan the entire project
if try_imports or local_mods:
# if we have local modules and they are not just us.
if num_local_mod or local_mods:
modules, try_imports, local_mods = project_import_modules(
self._project_path, self._ignores)
else:
modules, try_imports, local_mods = project_import_modules(
self._project_path, self._ignores)
@ -53,9 +70,7 @@ class GenerateReqs(object):
candidates = self._filter_modules(modules, local_mods)
# make sure we are in candidates
ourselves = self.__module__.split('.')[0]if self.__module__ else None
if ourselves and ourselves not in candidates:
candidates.add(ourselves)
candidates |= set(self._force_modules_reqs.keys())
logger.info('Check module in local environment.')
for name in candidates:
@ -69,9 +84,8 @@ class GenerateReqs(object):
# add local modules, so we know what is used but not installed.
for name in self._local_mods:
if name in modules:
if ourselves and name == ourselves:
from ...version import __version__
reqs.add(name, __version__, modules[name])
if name in self._force_modules_reqs:
reqs.add(name, self._force_modules_reqs[name], modules[name])
continue
# if this is a folder of our project, we can safely ignore it
@ -87,6 +101,14 @@ class GenerateReqs(object):
return reqs, try_imports, guess, local
@classmethod
def get_forced_modules(cls):
return cls._force_modules_reqs
@classmethod
def add_forced_module(cls, module_name, module_version):
cls._force_modules_reqs[module_name] = module_version
def _write_reqs(self, reqs):
print('Writing requirements to "{0}"'.format(
self._save_path))