diff --git a/requirements.txt b/requirements.txt
index 89987bc0..ec9e41d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,6 @@ jsonschema>=2.6.0
 numpy>=1.10
 pathlib2>=2.3.0
 Pillow>=4.1.1
-pigar==0.9.2
 plotly>=3.9.0
 psutil>=3.4.2
 pyparsing>=2.0.3
diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py
index b2e12ed2..42817a01 100644
--- a/trains/backend_interface/task/repo/scriptinfo.py
+++ b/trains/backend_interface/task/repo/scriptinfo.py
@@ -26,103 +26,18 @@ class ScriptRequirements(object):
     def __init__(self, root_folder):
         self._root_folder = root_folder
 
-    @staticmethod
-    def get_installed_pkgs_detail(reqs):
-        """
-        HACK: bugfix of the original pigar get_installed_pkgs_detail
-
-        Get mapping for import top level name
-        and install package name with version.
-        """
-        mapping = dict()
-
-        for path in sys.path:
-            if os.path.isdir(path) and path.rstrip('/').endswith(
-                    ('site-packages', 'dist-packages')):
-                new_mapping = reqs._search_path(path)
-                # BUGFIX:
-                # override with previous, just like python resolves imports, the first match is the one used.
-                # unlike the original implementation, where the last one is used.
-                new_mapping.update(mapping)
-                mapping = new_mapping
-
-                # HACK: prefer tensorflow_gpu over tensorflow
-                if 'tensorflow_gpu' in new_mapping:
-                    new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
-
-        return mapping
-
     def get_requirements(self):
         try:
-            from pigar import reqs
-            reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
-            from pigar.__main__ import GenerateReqs
-            from pigar.log import logger
-            logger.setLevel(logging.WARNING)
-            try:
-                # first try our version, if we fail revert to the internal implantation
-                installed_pkgs = self.get_installed_pkgs_detail(reqs)
-            except Exception:
-                installed_pkgs = reqs.get_installed_pkgs_detail()
+            from ....utilities.pigar.reqs import get_installed_pkgs_detail
+            from ....utilities.pigar.__main__ import GenerateReqs
+            installed_pkgs = get_installed_pkgs_detail()
             gr = GenerateReqs(save_path='', project_path=self._root_folder,
                               installed_pkgs=installed_pkgs,
                               ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
-            reqs, try_imports, guess = gr.extract_reqs()
-            return self.create_requirements_txt(reqs)
+            reqs, try_imports, guess, local_pks = gr.extract_reqs(module_callback=ScriptRequirements.add_trains_used_packages)
+            return self.create_requirements_txt(reqs, local_pks)
         except Exception:
             return '', ''
 
-    @staticmethod
-    def _patched_project_import_modules(project_path, ignores):
-        """
-        copied form pigar req.project_import_modules
-        patching, os.getcwd() is incorrectly used
-        """
-        from pigar.modules import ImportedModules
-        from pigar.reqs import file_import_modules
-        modules = ImportedModules()
-        try_imports = set()
-        local_mods = list()
-        ignore_paths = collections.defaultdict(set)
-        if not ignores:
-            ignore_paths[project_path].add('.git')
-        else:
-            for path in ignores:
-                parent_dir = os.path.dirname(path)
-                ignore_paths[parent_dir].add(os.path.basename(path))
-
-        if os.path.isfile(project_path):
-            fake_path = Path(project_path).name
-            with open(project_path, 'rb') as f:
-                fmodules, try_ipts = file_import_modules(fake_path, f.read())
-                modules |= fmodules
-                try_imports |= try_ipts
-        else:
-            cur_dir = project_path  # os.getcwd()
-            for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
-                if dirpath in ignore_paths:
-                    dirnames[:] = [d for d in dirnames
-                                   if d not in ignore_paths[dirpath]]
-                py_files = list()
-                for fn in files:
-                    # C extension.
-                    if fn.endswith('.so'):
-                        local_mods.append(fn[:-3])
-                    # Normal Python file.
-                    if fn.endswith('.py'):
-                        local_mods.append(fn[:-3])
-                        py_files.append(fn)
-                if '__init__.py' in files:
-                    local_mods.append(os.path.basename(dirpath))
-                for file in py_files:
-                    fpath = os.path.join(dirpath, file)
-                    fake_path = fpath.split(cur_dir)[1][1:]
-                    with open(fpath, 'rb') as f:
-                        fmodules, try_ipts = file_import_modules(fake_path, f.read())
-                        modules |= fmodules
-                        try_imports |= try_ipts
-
-        return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
-
     @staticmethod
     def add_trains_used_packages(modules):
         # hack: forcefully insert storage modules if we have them
@@ -159,7 +74,7 @@ class ScriptRequirements(object):
         return modules
 
     @staticmethod
-    def create_requirements_txt(reqs):
+    def create_requirements_txt(reqs, local_pks=None):
         # write requirements.txt
         try:
             conda_requirements = ''
@@ -188,6 +103,11 @@ class ScriptRequirements(object):
         # python version header
         requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
 
+        if local_pks:
+            requirements_txt += '\n# Local modules found - skipping:\n'
+            for k, v in local_pks.sorted_items():
+                requirements_txt += '# {0} == {1}\n'.format(k, v.version)
+
         # requirement summary
         requirements_txt += '\n'
         for k, v in reqs.sorted_items():
@@ -203,6 +123,13 @@ class ScriptRequirements(object):
         requirements_txt += '\n' + \
             '# Detailed import analysis\n' \
             '# **************************\n'
+
+        if local_pks:
+            for k, v in local_pks.sorted_items():
+                requirements_txt += '\n'
+                requirements_txt += '# IMPORT LOCAL PACKAGE {0}\n'.format(k)
+                requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
+
         for k, v in reqs.sorted_items():
             requirements_txt += '\n'
             if k == '-e':
@@ -262,9 +189,9 @@ class _JupyterObserver(object):
         # load pigar
         # noinspection PyBroadException
         try:
-            from pigar.reqs import get_installed_pkgs_detail, file_import_modules
-            from pigar.modules import ReqsModules
-            from pigar.log import logger
+            from ....utilities.pigar.reqs import get_installed_pkgs_detail, file_import_modules
+            from ....utilities.pigar.modules import ReqsModules
+            from ....utilities.pigar.log import logger
             logger.setLevel(logging.WARNING)
         except Exception:
             file_import_modules = None
diff --git a/trains/utilities/pigar/__init__.py b/trains/utilities/pigar/__init__.py
new file mode 100644
index 00000000..40a96afc
--- /dev/null
+++ b/trains/utilities/pigar/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/trains/utilities/pigar/__main__.py b/trains/utilities/pigar/__main__.py
new file mode 100644
index 00000000..7508b799
--- /dev/null
+++ b/trains/utilities/pigar/__main__.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import os
+import codecs
+
+from .reqs import project_import_modules, is_std_or_local_lib
+from .utils import lines_diff
+from .log import logger
+from .modules import ReqsModules
+
+
+class GenerateReqs(object):
+
+    def __init__(self, save_path, project_path, ignores,
+                 installed_pkgs, comparison_operator='=='):
+        self._save_path = save_path
+        self._project_path = project_path
+        self._ignores = ignores
+        self._installed_pkgs = installed_pkgs
+        self._maybe_local_mods = set()
+        self._local_mods = dict()
+        self._comparison_operator = comparison_operator
+
+    def extract_reqs(self, module_callback=None):
+        """Extract requirements from project."""
+
+        reqs = ReqsModules()
+        guess = ReqsModules()
+        local = ReqsModules()
+        modules, try_imports, local_mods = project_import_modules(
+            self._project_path, self._ignores)
+        if module_callback:
+            modules = module_callback(modules)
+        app_name = os.path.basename(self._project_path)
+        if app_name in local_mods:
+            local_mods.remove(app_name)
+
+        # Filter modules
+        candidates = self._filter_modules(modules, local_mods)
+
+        logger.info('Checking modules against the local environment.')
+        for name in candidates:
+            logger.info('Checking module: %s', name)
+            if name in self._installed_pkgs:
+                pkg_name, version = self._installed_pkgs[name]
+                reqs.add(pkg_name, version, modules[name])
+            else:
+                guess.add(name, 0, modules[name])
+
+        # Add local modules, so we know what is used but not installed.
+        for name in self._local_mods:
+            if name in modules:
+                relpath = os.path.relpath(self._local_mods[name], self._project_path)
+                if not relpath.startswith('.'):
+                    relpath = '.' + os.path.sep + relpath
+                local.add(name, relpath, modules[name])
+
+        return reqs, try_imports, guess, local
+
+    def _write_reqs(self, reqs):
+        print('Writing requirements to "{0}"'.format(
+            self._save_path))
+        with open(self._save_path, 'w+') as f:
+            f.write('# Requirements automatically generated by pigar.\n'
+                    '# https://github.com/damnever/pigar\n')
+            for k, v in reqs.sorted_items():
+                f.write('\n')
+                f.write(''.join(['# {0}\n'.format(c)
+                                 for c in v.comments.sorted_items()]))
+                if k == '-e':
+                    f.write('{0} {1}\n'.format(k, v.version))
+                elif v:
+                    f.write('{0} {1} {2}\n'.format(
+                        k, self._comparison_operator, v.version))
+                else:
+                    f.write('{0}\n'.format(k))
+
+    def _best_matchs(self, name, pkgs):
+        # If the imported name equals the package name.
+        if name in pkgs:
+            return [pkgs[pkgs.index(name)]]
+        # If not, return all possible packages.
+        return pkgs
+
+    def _filter_modules(self, modules, local_mods):
+        candidates = set()
+
+        logger.info('Filtering modules ...')
+        for module in modules:
+            logger.info('Checking module: %s', module)
+            if not module or module.startswith('.'):
+                continue
+            if module in local_mods:
+                self._maybe_local_mods.add(module)
+            module_std_local = is_std_or_local_lib(module)
+            if module_std_local is True:
+                continue
+            if isinstance(module_std_local, str):
+                self._local_mods[module] = module_std_local
+                continue
+            candidates.add(module)
+
+        return candidates
+
+    def _invalid_reqs(self, reqs):
+        for name, detail in reqs.sorted_items():
+            print(
+                '  {0} referenced from:\n    {1}'.format(
+                    name,
+                    '\n    '.join(detail.comments.sorted_items())
+                )
+            )
+
+    def _save_old_reqs(self):
+        if os.path.isfile(self._save_path):
+            with codecs.open(self._save_path, 'rb', 'utf-8') as f:
+                self._old_reqs = f.readlines()
+
+    def _reqs_diff(self):
+        if not hasattr(self, '_old_reqs'):
+            return
+        with codecs.open(self._save_path, 'rb', 'utf-8') as f:
+            new_reqs = f.readlines()
+        is_diff, diffs = lines_diff(self._old_reqs, new_reqs)
+        msg = 'Requirements file has been overwritten, '
+        if is_diff:
+            msg += 'here is the difference:'
+            print('{0}\n{1}'.format(msg, ''.join(diffs)), end='')
+        else:
+            msg += 'no difference.'
+        print(msg)
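# ---- Illustrative sketch (not part of this patch) ----
# How the vendored GenerateReqs is driven, mirroring the call in
# ScriptRequirements.get_requirements() above; the project path here is
# an assumption chosen for illustration:
#
#     from trains.utilities.pigar.reqs import get_installed_pkgs_detail
#     from trains.utilities.pigar.__main__ import GenerateReqs
#
#     gr = GenerateReqs(save_path='', project_path='/path/to/project',
#                       installed_pkgs=get_installed_pkgs_detail(),
#                       ignores=['.git', '__pycache__'])
#     reqs, try_imports, guess, local_pks = gr.extract_reqs()
#     for name, detail in reqs.sorted_items():
#         print('{0}=={1}'.format(name, detail.version))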
diff --git a/trains/utilities/pigar/extractor/__init__.py b/trains/utilities/pigar/extractor/__init__.py
new file mode 100644
index 00000000..b9b0ac03
--- /dev/null
+++ b/trains/utilities/pigar/extractor/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+
+from ..utils import PY32
+
+
+if PY32:
+    from .thread_extractor import ThreadExtractor as Extractor
+else:
+    from .gevent_extractor import GeventExtractor as Extractor
diff --git a/trains/utilities/pigar/extractor/extractor.py b/trains/utilities/pigar/extractor/extractor.py
new file mode 100644
index 00000000..08464f6c
--- /dev/null
+++ b/trains/utilities/pigar/extractor/extractor.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import multiprocessing
+
+
+class BaseExtractor(object):
+
+    def __init__(self, names, max_workers=None):
+        self._names = names
+        self._max_workers = max_workers or (multiprocessing.cpu_count() * 4)
+
+    def run(self, job):
+        try:
+            self.extract(job)
+            self.wait_complete()
+        except KeyboardInterrupt:
+            print('** Shutting down ...')
+            self.shutdown()
+        else:
+            print('^.^ Extracting all packages done!')
+        finally:
+            self.final()
+
+    def extract(self, job):
+        raise NotImplementedError
+
+    def wait_complete(self):
+        raise NotImplementedError
+
+    def shutdown(self):
+        raise NotImplementedError
+
+    def final(self):
+        pass
diff --git a/trains/utilities/pigar/extractor/gevent_extractor.py b/trains/utilities/pigar/extractor/gevent_extractor.py
new file mode 100644
index 00000000..ba7363c5
--- /dev/null
+++ b/trains/utilities/pigar/extractor/gevent_extractor.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import sys
+
+import greenlet
+from gevent.pool import Pool
+
+from .extractor import BaseExtractor
+from ..log import logger
+
+
+class GeventExtractor(BaseExtractor):
+
+    def __init__(self, names, max_workers=222):
+        super(self.__class__, self).__init__(names, max_workers)
+        self._pool = Pool(self._max_workers)
+        self._exited_greenlets = 0
+
+    def extract(self, job):
+        job = self._job_wrapper(job)
+        for name in self._names:
+            if self._pool.full():
+                self._pool.wait_available()
+            self._pool.spawn(job, name)
+
+    def _job_wrapper(self, job):
+        def _job(name):
+            result = None
+            try:
+                result = job(name)
+            except greenlet.GreenletExit:
+                self._exited_greenlets += 1
+            except Exception:
+                e = sys.exc_info()[1]
+                logger.error('Extracting "{0}", got: {1}'.format(name, e))
+            return result
+        return _job
+
+    def wait_complete(self):
+        self._pool.join()
+
+    def shutdown(self):
+        self._pool.kill(block=True)
+
+    def final(self):
+        count = self._exited_greenlets
+        if count != 0:
+            print('** {0} running jobs exited.'.format(count))
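# ---- Illustrative sketch (not part of this patch) ----
# Both extractors implement the BaseExtractor contract defined above:
# construct with a list of names, then call run(job) with a callable that
# receives one name at a time. A minimal sketch using the thread variant:
#
#     from trains.utilities.pigar.extractor.thread_extractor import ThreadExtractor
#
#     def job(name):
#         print('would fetch metadata for {0}'.format(name))
#
#     ThreadExtractor(['numpy', 'requests'], max_workers=4).run(job)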
diff --git a/trains/utilities/pigar/extractor/thread_extractor.py b/trains/utilities/pigar/extractor/thread_extractor.py
new file mode 100644
index 00000000..e0a4ce73
--- /dev/null
+++ b/trains/utilities/pigar/extractor/thread_extractor.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import concurrent.futures
+
+from .extractor import BaseExtractor
+
+from ..log import logger
+
+
+class ThreadExtractor(BaseExtractor):
+    """Extractor that uses a thread pool to execute tasks.
+
+    Can be used to extract /simple/<package> or /pypi/<package>/json.
+
+    FIXME: cannot deliver SIGINT to threads in Python 2.
+    """
+
+    def __init__(self, names, max_workers=None):
+        super(self.__class__, self).__init__(names, max_workers)
+        self._futures = dict()
+
+    def extract(self, job):
+        """Extract URL by package name."""
+        with concurrent.futures.ThreadPoolExecutor(
+                max_workers=self._max_workers) as executor:
+            for name in self._names:
+                self._futures[executor.submit(job, name)] = name
+
+    def wait_complete(self):
+        """Wait for all futures to complete."""
+        for future in concurrent.futures.as_completed(self._futures.keys()):
+            try:
+                error = future.exception()
+            except concurrent.futures.CancelledError:
+                break
+            name = self._futures[future]
+            if error is not None:
+                err_msg = 'Extracting "{0}", got: {1}'.format(name, error)
+                logger.error(err_msg)
+
+    def shutdown(self):
+        for future in self._futures:
+            future.cancel()
diff --git a/trains/utilities/pigar/log.py b/trains/utilities/pigar/log.py
new file mode 100644
index 00000000..57d3ee15
--- /dev/null
+++ b/trains/utilities/pigar/log.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import logging.handlers
+
+
+logger = logging.getLogger('pigar')
+logger.setLevel(logging.WARNING)
diff --git a/trains/utilities/pigar/modules.py b/trains/utilities/pigar/modules.py
new file mode 100644
index 00000000..da71529a
--- /dev/null
+++ b/trains/utilities/pigar/modules.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import collections
+
+
+# FIXME: Just a workaround, not a radical cure.
+_special_cases = {
+    "dogpile.cache": "dogpile.cache",
+    "dogpile.core": "dogpile.core",
+    "ruamel.yaml": "ruamel.yaml",
+    "ruamel.ordereddict": "ruamel.ordereddict",
+}
+
+
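# ---- Illustrative sketch (not part of this patch) ----
# ImportedModules.add() (below) normalizes dotted import names to a
# top-level name, with the namespace packages above special-cased:
#
#     mods = ImportedModules()
#     mods.add('flask.ext.sqlalchemy', 'app.py', 3)  # -> 'flask', 'flask_sqlalchemy'
#     mods.add('ruamel.yaml.comments', 'app.py', 4)  # -> 'ruamel.yaml'
#     mods.add('os.path', 'app.py', 5)               # -> 'os'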
+class Modules(dict):
+    """Modules object is used to store module information."""
+
+    def __init__(self):
+        super(Modules, self).__init__()
+
+
+class ImportedModules(Modules):
+
+    def __init__(self):
+        super(ImportedModules, self).__init__()
+
+    def add(self, name, file, lineno):
+        if name is None:
+            return
+
+        names = list()
+        special_name = '.'.join(name.split('.')[:2])
+        # Flask extension.
+        if name.startswith('flask.ext.'):
+            names.append('flask')
+            names.append('flask_' + name.split('.')[2])
+        # Special cases.
+        elif special_name in _special_cases:
+            names.append(_special_cases[special_name])
+        # Other.
+        elif '.' in name and not name.startswith('.'):
+            names.append(name.split('.')[0])
+        else:
+            names.append(name)
+
+        for nm in names:
+            if nm not in self:
+                self[nm] = _Locations()
+            self[nm].add(file, lineno)
+
+    def __or__(self, obj):
+        for name, locations in obj.items():
+            for file, linenos in locations.items():
+                for lineno in linenos:
+                    self.add(name, file, lineno)
+        return self
+
+
+class ReqsModules(Modules):
+
+    _Detail = collections.namedtuple('Detail', ['version', 'comments'])
+
+    def __init__(self):
+        super(ReqsModules, self).__init__()
+        self._sorted = None
+
+    def add(self, package, version, locations):
+        if package in self:
+            self[package].comments.extend(locations)
+        else:
+            self[package] = self._Detail(version, locations)
+
+    def sorted_items(self):
+        if self._sorted is None:
+            self._sorted = sorted(self.items())
+        return self._sorted
+
+    def remove(self, *names):
+        for name in names:
+            if name in self:
+                self.pop(name)
+        self._sorted = None
+
+
+class _Locations(dict):
+    """_Locations stores code locations (file, line numbers)."""
+
+    def __init__(self):
+        super(_Locations, self).__init__()
+        self._sorted = None
+
+    def add(self, file, lineno):
+        if file in self and lineno not in self[file]:
+            self[file].append(lineno)
+        else:
+            self[file] = [lineno]
+
+    def extend(self, obj):
+        for file, linenos in obj.items():
+            for lineno in linenos:
+                self.add(file, lineno)
+
+    def sorted_items(self):
+        if self._sorted is None:
+            self._sorted = [
+                '{0}: {1}'.format(f, ','.join([str(n) for n in sorted(ls)]))
+                for f, ls in sorted(self.items())
+            ]
+        return self._sorted
diff --git a/trains/utilities/pigar/reqs.py b/trains/utilities/pigar/reqs.py
new file mode 100644
index 00000000..3de530ca
--- /dev/null
+++ b/trains/utilities/pigar/reqs.py
@@ -0,0 +1,398 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import os
+import sys
+import fnmatch
+import importlib
+import imp
+import ast
+import doctest
+import collections
+import functools
+from pathlib2 import Path
+try:
+    from types import FileType  # py2
+except ImportError:
+    from io import IOBase as FileType  # py3
+
+from .log import logger
+from .utils import parse_git_config
+from .modules import ImportedModules
+
+
+def project_import_modules(project_path, ignores):
+    """
+    Copied from pigar reqs.project_import_modules and patched:
+    os.getcwd() was used incorrectly as the base path.
+    """
+    modules = ImportedModules()
+    try_imports = set()
+    local_mods = list()
+    ignore_paths = collections.defaultdict(set)
+    if not ignores:
+        ignore_paths[project_path].add('.git')
+    else:
+        for path in ignores:
+            parent_dir = os.path.dirname(path)
+            ignore_paths[parent_dir].add(os.path.basename(path))
+
+    if os.path.isfile(project_path):
+        fake_path = Path(project_path).name
+        with open(project_path, 'rb') as f:
+            fmodules, try_ipts = file_import_modules(fake_path, f.read())
+            modules |= fmodules
+            try_imports |= try_ipts
+    else:
+        cur_dir = project_path  # os.getcwd()
+        for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
+            if dirpath in ignore_paths:
+                dirnames[:] = [d for d in dirnames
+                               if d not in ignore_paths[dirpath]]
+            py_files = list()
+            for fn in files:
+                # C extension.
+                if fn.endswith('.so'):
+                    local_mods.append(fn[:-3])
+                # Normal Python file.
+                if fn.endswith('.py'):
+                    local_mods.append(fn[:-3])
+                    py_files.append(fn)
+            if '__init__.py' in files:
+                local_mods.append(os.path.basename(dirpath))
+            for file in py_files:
+                fpath = os.path.join(dirpath, file)
+                fake_path = fpath.split(cur_dir)[1][1:]
+                with open(fpath, 'rb') as f:
+                    fmodules, try_ipts = file_import_modules(fake_path, f.read())
+                    modules |= fmodules
+                    try_imports |= try_ipts
+
+    return modules, try_imports, local_mods
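# ---- Illustrative sketch (not part of this patch) ----
# Driving the patched walker directly; the paths are assumptions chosen
# for illustration:
#
#     modules, try_imports, local_mods = project_import_modules(
#         '/path/to/project', ignores=['/path/to/project/.git'])
#     # modules:     ImportedModules mapping top-level names to code locations
#     # try_imports: names imported inside try/except blocks
#     # local_mods:  module names defined by the project's own .py/.so files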
+
+
+def file_import_modules(fpath, fdata):
+    """Get all modules imported by a single file."""
+    modules = ImportedModules()
+    str_codes = collections.deque([(fdata, 1)])
+    try_imports = set()
+
+    while str_codes:
+        str_code, lineno = str_codes.popleft()
+        ic = ImportChecker(fpath, lineno)
+        try:
+            parsed = ast.parse(str_code)
+            ic.visit(parsed)
+        # Ignore SyntaxError in Python code.
+        except SyntaxError:
+            pass
+        modules |= ic.modules
+        str_codes.extend(ic.str_codes)
+        try_imports |= ic.try_imports
+        del ic
+
+    return modules, try_imports
+
+
+class ImportChecker(object):
+
+    def __init__(self, fpath, lineno):
+        self._fpath = fpath
+        self._lineno = lineno - 1
+        self._modules = ImportedModules()
+        self._str_codes = collections.deque()
+        self._try_imports = set()
+
+    def visit_Import(self, node, try_=False):
+        """Handle `import a [as b]`."""
+        lineno = node.lineno + self._lineno
+        for alias in node.names:
+            self._modules.add(alias.name, self._fpath, lineno)
+            if try_:
+                self._try_imports.add(alias.name)
+
+    def visit_ImportFrom(self, node, try_=False):
+        """
+        Handle `from a import b [as c]`. If node.level is not 0,
+        the import is relative, e.g. `from .a import b`.
+        """
+        mod_name = node.module
+        level = node.level
+        if mod_name is None:
+            level -= 1
+            mod_name = ''
+        for alias in node.names:
+            name = level*'.' + mod_name + '.' + alias.name
+            self._modules.add(name, self._fpath, node.lineno + self._lineno)
+            if try_:
+                self._try_imports.add(name)
+
+    def visit_TryExcept(self, node):
+        """
+        Modules imported inside `try/except` blocks may be absent on
+        purpose; they may come from a different Python version.
+        """
+        for ipt in node.body:
+            if ipt.__class__.__name__.startswith('Import'):
+                method = 'visit_' + ipt.__class__.__name__
+                getattr(self, method)(ipt, True)
+        for handler in node.handlers:
+            for ipt in handler.body:
+                if ipt.__class__.__name__.startswith('Import'):
+                    method = 'visit_' + ipt.__class__.__name__
+                    getattr(self, method)(ipt, True)
+
+    # For Python 3.3+
+    visit_Try = visit_TryExcept
+
+    def visit_Exec(self, node):
+        """
+        Check `expression` of `exec(expression[, globals[, locals]])`.
+        **Only available in Python 2.**
+        """
+        if hasattr(node.body, 's'):
+            self._str_codes.append((node.body.s, node.lineno + self._lineno))
+        # PR#13: https://github.com/damnever/pigar/pull/13
+        # Sometimes exec statement may be called with tuple in Py2.7.6
+        elif hasattr(node.body, 'elts') and len(node.body.elts) >= 1:
+            self._str_codes.append(
+                (node.body.elts[0].s, node.lineno + self._lineno))
+
+    def visit_Expr(self, node):
+        """
+        Check `expression` of `eval(expression[, globals[, locals]])`.
+        Check `expression` of `exec(expression[, globals[, locals]])`
+        in Python 3.
+        Check `name` of `__import__(name[, globals[, locals[,
+        fromlist[, level]]]])`.
+        Check `name` or `package` of `importlib.import_module(name,
+        package=None).
+ """ + # Built-in functions + value = node.value + if isinstance(value, ast.Call): + if hasattr(value.func, 'id'): + if (value.func.id == 'eval' and + hasattr(node.value.args[0], 's')): + self._str_codes.append( + (node.value.args[0].s, node.lineno + self._lineno)) + # **`exec` function in Python 3.** + elif (value.func.id == 'exec' and + hasattr(node.value.args[0], 's')): + self._str_codes.append( + (node.value.args[0].s, node.lineno + self._lineno)) + # `__import__` function. + elif (value.func.id == '__import__' and + len(node.value.args) > 0 and + hasattr(node.value.args[0], 's')): + self._modules.add(node.value.args[0].s, self._fpath, + node.lineno + self._lineno) + # `import_module` function. + elif getattr(value.func, 'attr', '') == 'import_module': + module = getattr(value.func, 'value', None) + if (module is not None and + getattr(module, 'id', '') == 'importlib'): + args = node.value.args + arg_len = len(args) + if arg_len > 0 and hasattr(args[0], 's'): + name = args[0].s + if not name.startswith('.'): + self._modules.add(name, self._fpath, + node.lineno + self._lineno) + elif arg_len == 2 and hasattr(args[1], 's'): + self._modules.add(args[1].s, self._fpath, + node.lineno + self._lineno) + + def visit_FunctionDef(self, node): + """ + Check docstring of function, if docstring is used for doctest. + """ + docstring = self._parse_docstring(node) + if docstring: + self._str_codes.append((docstring, node.lineno + self._lineno + 2)) + + def visit_ClassDef(self, node): + """ + Check docstring of class, if docstring is used for doctest. + """ + docstring = self._parse_docstring(node) + if docstring: + self._str_codes.append((docstring, node.lineno + self._lineno + 2)) + + def visit(self, node): + """Visit a node, no recursively.""" + for node in ast.walk(node): + method = 'visit_' + node.__class__.__name__ + getattr(self, method, lambda x: x)(node) + + @staticmethod + def _parse_docstring(node): + """Extract code from docstring.""" + docstring = ast.get_docstring(node) + if docstring: + parser = doctest.DocTestParser() + try: + dt = parser.get_doctest(docstring, {}, None, None, None) + except ValueError: + # >>> 'abc' + pass + else: + examples = dt.examples + return '\n'.join([example.source for example in examples]) + return None + + @property + def modules(self): + return self._modules + + @property + def str_codes(self): + return self._str_codes + + @property + def try_imports(self): + return set((name.split('.')[0] if name and '.' in name else name) + for name in self._try_imports) + + +def _checked_cache(func): + checked = dict() + + @functools.wraps(func) + def _wrapper(name): + if name not in checked: + checked[name] = func(name) + return checked[name] + + return _wrapper + + +@_checked_cache +def is_std_or_local_lib(name): + """Check whether it is stdlib module. 
+    True if stdlib
+    False if installed package
+    str (the module path) if local library
+    """
+    exist = True
+    module_info = ('', '', '')
+    try:
+        module_info = imp.find_module(name)
+    except ImportError:
+        try:
+            # __import__(name)
+            importlib.import_module(name)
+            module_info = imp.find_module(name)
+            sys.modules.pop(name)
+        except ImportError:
+            exist = False
+    # Testcase: ResourceWarning
+    if isinstance(module_info[0], FileType):
+        module_info[0].close()
+    mpath = module_info[1]
+    if exist and mpath is not None:
+        if ('site-packages' in mpath or
+                'dist-packages' in mpath or
+                ('bin/' in mpath and mpath.endswith('.py'))):
+            exist = False
+        elif ((sys.prefix not in mpath) and
+                (sys.base_exec_prefix not in mpath) and
+                (sys.base_prefix not in mpath)):
+            exist = mpath
+
+    return exist
+
+
+def get_installed_pkgs_detail():
+    """
+    HACK: bugfix of the original pigar get_installed_pkgs_detail.
+
+    Get a mapping from import top-level names to the installed
+    (package name, version) pair.
+    """
+    mapping = dict()
+
+    for path in sys.path:
+        if os.path.isdir(path) and path.rstrip('/').endswith(
+                ('site-packages', 'dist-packages')):
+            new_mapping = _search_path(path)
+            # BUGFIX:
+            # Override with the previous mapping: just like Python resolves
+            # imports, the first match is the one used, unlike the original
+            # implementation, where the last one was used.
+            new_mapping.update(mapping)
+            mapping = new_mapping
+
+            # HACK: prefer tensorflow_gpu over tensorflow
+            if 'tensorflow_gpu' in new_mapping:
+                new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
+
+    return mapping
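# ---- Illustrative sketch (not part of this patch) ----
# The BUGFIX above makes the mapping honor sys.path order. With a
# hypothetical package 'foo' installed both in a virtualenv (2.0) and
# system-wide (1.0), and the virtualenv first on sys.path:
#
#     get_installed_pkgs_detail()['foo']  # -> ('foo', '2.0'), first match wins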
+
+
+def _search_path(path):
+    mapping = dict()
+
+    for file in os.listdir(path):
+        # Installed from PyPI.
+        if fnmatch.fnmatch(file, '*-info'):
+            top_level = os.path.join(path, file, 'top_level.txt')
+            pkg_name, version = file.split('-')[:2]
+            if version.endswith('dist'):
+                version = version.rsplit('.', 1)[0]
+            # Issue for Ubuntu: sudo pip install xxx
+            elif version.endswith('egg'):
+                version = version.rsplit('.', 1)[0]
+            mapping[pkg_name] = (pkg_name, version)
+            if not os.path.isfile(top_level):
+                continue
+            with open(top_level, 'r') as f:
+                for line in f:
+                    mapping[line.strip()] = (pkg_name, version)
+
+        # Installed from a local source ("pip install -e"), tracked in git.
+        elif fnmatch.fnmatch(file, '*-link'):
+            link = os.path.join(path, file)
+            if not os.path.isfile(link):
+                continue
+            # Link path.
+            with open(link, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if line != '.':
+                        dev_dir = line
+            if not dev_dir:
+                continue
+            # Egg info path.
+            info_dir = [_file for _file in os.listdir(dev_dir)
+                        if _file.endswith('egg-info')]
+            if not info_dir:
+                continue
+            info_dir = info_dir[0]
+            top_level = os.path.join(dev_dir, info_dir, 'top_level.txt')
+            # Check whether it can be imported.
+            if not os.path.isfile(top_level):
+                continue
+
+            # Check .git dir.
+            git_path = os.path.join(dev_dir, '.git')
+            if os.path.isdir(git_path):
+                config = parse_git_config(git_path)
+                url = config.get('remote "origin"', {}).get('url')
+                if not url:
+                    continue
+                branch = 'branch "master"'
+                if branch not in config:
+                    for section in config:
+                        if 'branch' in section:
+                            branch = section
+                            break
+                if not branch:
+                    continue
+                branch = branch.split()[1][1:-1]
+
+                pkg_name = info_dir.split('.egg')[0]
+                git_url = 'git+{0}@{1}#egg={2}'.format(url, branch, pkg_name)
+                with open(top_level, 'r') as f:
+                    for line in f:
+                        mapping[line.strip()] = ('-e', git_url)
+    return mapping
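# ---- Illustrative sketch (not part of this patch) ----
# For an editable ("pip install -e") install, _search_path() above maps
# each top-level name to a ('-e', 'git+<url>@<branch>#egg=<name>') pair,
# which create_requirements_txt() renders as a "-e git+..." line. The
# names below are assumptions for illustration:
#
#     # site-packages/mypkg.egg-link -> /home/dev/mypkg (a git clone)
#     mapping['mypkg']  # ('-e', 'git+https://github.com/me/mypkg@master#egg=mypkg')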
diff --git a/trains/utilities/pigar/unpack.py b/trains/utilities/pigar/unpack.py
new file mode 100644
index 00000000..bfdafd31
--- /dev/null
+++ b/trains/utilities/pigar/unpack.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import tarfile
+import zipfile
+import re
+import string
+import io
+
+
+class Archive(object):
+    """Archive provides a consistent interface for unpacking
+    compressed files.
+    """
+
+    def __init__(self, filename, fileobj):
+        self._filename = filename
+        self._fileobj = fileobj
+        self._file = None
+        self._names = None
+        self._read = None
+
+    @property
+    def filename(self):
+        return self._filename
+
+    @property
+    def names(self):
+        """Get the name list lazily, only when it is required."""
+        if self._file is None:
+            self._prepare()
+        if not hasattr(self, '_namelist'):
+            self._namelist = self._names()
+        return self._namelist
+
+    def close(self):
+        """Close file object."""
+        if self._file is not None:
+            self._file.close()
+        if hasattr(self, '_namelist'):
+            del self._namelist
+        self._filename = self._fileobj = None
+        self._file = self._names = self._read = None
+
+    def read(self, filename):
+        """Read one file from the archive."""
+        if self._file is None:
+            self._prepare()
+        return self._read(filename)
+
+    def unpack(self, to_path):
+        """Unpack compressed files to path."""
+        if self._file is None:
+            self._prepare()
+        self._safe_extractall(to_path)
+
+    def _prepare(self):
+        if self._filename.endswith(('.tar.gz', '.tar.bz2', '.tar.xz')):
+            self._prepare_tarball()
+        # An .egg file is actually just a .zip file
+        # with a different extension, .whl too.
+        elif self._filename.endswith(('.zip', '.egg', '.whl')):
+            self._prepare_zip()
+        else:
+            raise ValueError("unreadable: {0}".format(self._filename))
+
+    def _safe_extractall(self, to_path='.'):
+        unsafe = []
+        for name in self.names:
+            if not self.is_safe(name):
+                unsafe.append(name)
+        if unsafe:
+            raise ValueError("unsafe to unpack: {}".format(unsafe))
+        self._file.extractall(to_path)
+
+    def _prepare_zip(self):
+        self._file = zipfile.ZipFile(self._fileobj)
+        self._names = self._file.namelist
+        self._read = self._file.read
+
+    def _prepare_tarball(self):
+        # tarfile has no read method
+        def _read(filename):
+            f = self._file.extractfile(filename)
+            return f.read()
+
+        self._file = tarfile.open(mode='r:*', fileobj=self._fileobj)
+        self._names = self._file.getnames
+        self._read = _read
+
+    def is_safe(self, filename):
+        return not (filename.startswith(("/", "\\")) or
+                    (len(filename) > 1 and filename[1] == ":" and
+                     filename[0] in string.ascii_letters) or
+                    re.search(r"[.][.][/\\]", filename))
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+
+def top_level(url, data):
+    """Read top level names from a compressed file."""
+    sb = io.BytesIO(data)
+    txt = None
+    with Archive(url, sb) as archive:
+        file = None
+        for name in archive.names:
+            if name.lower().endswith('top_level.txt'):
+                file = name
+                break
+        if file:
+            txt = archive.read(file).decode('utf-8')
+    sb.close()
+    return [name.replace('/', '.') for name in txt.splitlines()] if txt else []
diff --git a/trains/utilities/pigar/utils.py b/trains/utilities/pigar/utils.py
new file mode 100644
index 00000000..6c8ebf34
--- /dev/null
+++ b/trains/utilities/pigar/utils.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division, absolute_import
+
+import os
+import re
+import sys
+import difflib
+
+
+PY32 = sys.version_info[:2] == (3, 2)
+
+if sys.version_info[0] == 3:
+    binary_type = bytes
+else:
+    binary_type = str
+
+
+class Dict(dict):
+    """Convert dict keys to attributes."""
+
+    def __init__(self, *args, **kwargs):
+        super(Dict, self).__init__(*args, **kwargs)
+
+    def __getattr__(self, name):
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError('"{0}"'.format(name))
+
+    def __setattr__(self, name, value):
+        self[name] = value
+
+
+def parse_reqs(fpath):
+    """Parse requirements file."""
+    pkg_v_re = re.compile(r'^(?P<pkg>[^><==]+)[><==]{,2}(?P<version>.*)$')
+    reqs = dict()
+    with open(fpath, 'r') as f:
+        for line in f:
+            if line.startswith('#'):
+                continue
+            m = pkg_v_re.match(line.strip())
+            if m:
+                d = m.groupdict()
+                reqs[d['pkg'].strip()] = d['version'].strip()
+    return reqs
+
+
+def cmp_to_key(cmp_func):
+    """Convert a cmp= function into a key= function."""
+    class K(object):
+        def __init__(self, obj, *args):
+            self.obj = obj
+
+        def __lt__(self, other):
+            return cmp_func(self.obj, other.obj) < 0
+
+        def __gt__(self, other):
+            return cmp_func(self.obj, other.obj) > 0
+
+        def __eq__(self, other):
+            return cmp_func(self.obj, other.obj) == 0
+
+    return K
+
+
+def compare_version(version1, version2):
+    """Compare version numbers, such as 1.1.1 and 1.1b2.0."""
+    v1, v2 = list(), list()
+
+    for item in version1.split('.'):
+        if item.isdigit():
+            v1.append(int(item))
+        else:
+            v1.extend([i for i in _group_alnum(item)])
+    for item in version2.split('.'):
+        if item.isdigit():
+            v2.append(int(item))
+        else:
+            v2.extend([i for i in _group_alnum(item)])
+
+    while v1 and v2:
+        item1, item2 = v1.pop(0), v2.pop(0)
+        if item1 > item2:
+            return 1
+        elif item1 < item2:
+            return -1
+
+    if v1:
+        return 1
+    elif v2:
+        return -1
+    return 0
+
+
+def _group_alnum(s):
+    tmp = list()
+    flag = 1 if s[0].isdigit() else 0
+    for c in s:
+        if c.isdigit():
+            if flag == 0:
+                yield ''.join(tmp)
+                tmp = list()
+                flag = 1
+            tmp.append(c)
+        elif c.isalpha():
+            if flag == 1:
+                yield int(''.join(tmp))
+                tmp = list()
+                flag = 0
+            tmp.append(c)
+    last = ''.join(tmp)
+    yield (int(last) if flag else last)
+
+
+def parse_git_config(path):
+    """Parse git config file."""
+    config = dict()
+    section = None
+
+    with open(os.path.join(path, 'config'), 'r') as f:
+        for line in f:
+            line = line.strip()
+            if line.startswith('['):
+                section = line[1:-1].strip()
+                config[section] = dict()
+            elif section:
+                key, value = line.replace(' ', '').split('=')
+                config[section][key] = value
+    return config
+
+
+def lines_diff(lines1, lines2):
+    """Show the difference between lines."""
+    is_diff = False
+    diffs = list()
+
+    for line in difflib.ndiff(lines1, lines2):
+        if not is_diff and line[0] in ('+', '-'):
+            is_diff = True
+        diffs.append(line)
+
+    return is_diff, diffs
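# ---- Illustrative sketch (not part of this patch) ----
# Expected behavior of the helpers above; inputs are assumptions chosen
# for illustration:
#
#     compare_version('1.10.0', '1.9.1')     # -> 1  (1.10 > 1.9)
#     compare_version('1.1b2.0', '1.1b2.0')  # -> 0
#
#     is_diff, diffs = lines_diff(['a\n'], ['b\n'])
#     # is_diff -> True; diffs holds the ndiff lines ('- a', '+ b', ...)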