diff --git a/docs/trains.conf b/docs/trains.conf index 2a8f9efa..da0601c4 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -151,6 +151,12 @@ sdk { # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead. default_output_uri: "" + # Default auto generated requirements optimize for smaller requirements + # If True, analyze the entire repository regardless of the entry point. + # If False, first analyze the entry point script, if it does not contain other to local files, + # do not analyze the entire repository. + force_analyze_entire_repo: false + # Development mode worker worker { # Status report period in seconds diff --git a/trains/utilities/pigar/__main__.py b/trains/utilities/pigar/__main__.py index 42ca3683..506f249c 100644 --- a/trains/utilities/pigar/__main__.py +++ b/trains/utilities/pigar/__main__.py @@ -12,6 +12,7 @@ from .modules import ReqsModules class GenerateReqs(object): + _force_modules_reqs = dict() def __init__(self, save_path, project_path, ignores, installed_pkgs, comparison_operator='=='): @@ -30,18 +31,34 @@ class GenerateReqs(object): guess = ReqsModules() local = ReqsModules() - # make the entry point absolute (relative to the root path) + num_local_mod = 0 + if self.__module__: + # create a copy, do not change the class set + our_module = self.__module__.split('.')[0] + if our_module and our_module not in self._force_modules_reqs: + from ...version import __version__ + self._force_modules_reqs[our_module] = __version__ + + # make the entry point absolute (relative o the root path) if entry_point_filename and not os.path.isabs(entry_point_filename): entry_point_filename = os.path.join(self._project_path, entry_point_filename) \ if os.path.isdir(self._project_path) else None # check if the entry point script is self contained, i.e. does not use the rest of the project - if entry_point_filename and os.path.isfile(entry_point_filename): + if entry_point_filename and os.path.isfile(entry_point_filename) and not self._local_mods: modules, try_imports, local_mods = project_import_modules(entry_point_filename, self._ignores) + if not local_mods: + # update the self._local_mods + self._filter_modules(modules, local_mods) + # check how many local modules we have, excluding ourselves + num_local_mod = len(set(self._local_mods.keys()) - set(self._force_modules_reqs.keys())) + # if we have any module/package we cannot find, take no chances and scan the entire project - if try_imports or local_mods: + # if we have local modules and they are not just us. + if num_local_mod or local_mods: modules, try_imports, local_mods = project_import_modules( self._project_path, self._ignores) + else: modules, try_imports, local_mods = project_import_modules( self._project_path, self._ignores) @@ -53,9 +70,7 @@ class GenerateReqs(object): candidates = self._filter_modules(modules, local_mods) # make sure we are in candidates - ourselves = self.__module__.split('.')[0]if self.__module__ else None - if ourselves and ourselves not in candidates: - candidates.add(ourselves) + candidates |= set(self._force_modules_reqs.keys()) logger.info('Check module in local environment.') for name in candidates: @@ -69,9 +84,8 @@ class GenerateReqs(object): # add local modules, so we know what is used but not installed. for name in self._local_mods: if name in modules: - if ourselves and name == ourselves: - from ...version import __version__ - reqs.add(name, __version__, modules[name]) + if name in self._force_modules_reqs: + reqs.add(name, self._force_modules_reqs[name], modules[name]) continue # if this is a folder of our project, we can safely ignore it @@ -87,6 +101,14 @@ class GenerateReqs(object): return reqs, try_imports, guess, local + @classmethod + def get_forced_modules(cls): + return cls._force_modules_reqs + + @classmethod + def add_forced_module(cls, module_name, module_version): + cls._force_modules_reqs[module_name] = module_version + def _write_reqs(self, reqs): print('Writing requirements to "{0}"'.format( self._save_path))