mirror of https://github.com/clearml/clearml, synced 2025-04-27 09:49:14 +00:00

Integrate pigar into Trains

parent 8ee2bd1844, commit 146da439e7
requirements.txt
@@ -11,7 +11,6 @@ jsonschema>=2.6.0
numpy>=1.10
pathlib2>=2.3.0
Pillow>=4.1.1
-pigar==0.9.2
plotly>=3.9.0
psutil>=3.4.2
pyparsing>=2.0.3
@@ -26,103 +26,18 @@ class ScriptRequirements(object):
    def __init__(self, root_folder):
        self._root_folder = root_folder

-    @staticmethod
-    def get_installed_pkgs_detail(reqs):
-        """
-        HACK: bugfix of the original pigar get_installed_pkgs_detail
-
-        Get the mapping from import top-level names
-        to installed package name and version.
-        """
-        mapping = dict()
-
-        for path in sys.path:
-            if os.path.isdir(path) and path.rstrip('/').endswith(
-                    ('site-packages', 'dist-packages')):
-                new_mapping = reqs._search_path(path)
-                # BUGFIX:
-                # override with the previous mapping: just like Python resolves
-                # imports, the first match is the one used,
-                # unlike the original implementation, where the last one is used.
-                new_mapping.update(mapping)
-                mapping = new_mapping
-
-                # HACK: prefer tensorflow_gpu over tensorflow
-                if 'tensorflow_gpu' in new_mapping:
-                    new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']
-
-        return mapping
-
    def get_requirements(self):
        try:
-            from pigar import reqs
-            reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
-            from pigar.__main__ import GenerateReqs
-            from pigar.log import logger
-            logger.setLevel(logging.WARNING)
-            try:
-                # first try our version; if it fails, revert to the internal implementation
-                installed_pkgs = self.get_installed_pkgs_detail(reqs)
-            except Exception:
-                installed_pkgs = reqs.get_installed_pkgs_detail()
+            from ....utilities.pigar.reqs import get_installed_pkgs_detail
+            from ....utilities.pigar.__main__ import GenerateReqs
+            installed_pkgs = get_installed_pkgs_detail()
            gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
                              ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
-            reqs, try_imports, guess = gr.extract_reqs()
-            return self.create_requirements_txt(reqs)
+            reqs, try_imports, guess, local_pks = gr.extract_reqs(module_callback=ScriptRequirements.add_trains_used_packages)
+            return self.create_requirements_txt(reqs, local_pks)
        except Exception:
            return '', ''

-    @staticmethod
-    def _patched_project_import_modules(project_path, ignores):
-        """
-        copied from pigar reqs.project_import_modules;
-        patched because os.getcwd() was incorrectly used
-        """
-        from pigar.modules import ImportedModules
-        from pigar.reqs import file_import_modules
-        modules = ImportedModules()
-        try_imports = set()
-        local_mods = list()
-        ignore_paths = collections.defaultdict(set)
-        if not ignores:
-            ignore_paths[project_path].add('.git')
-        else:
-            for path in ignores:
-                parent_dir = os.path.dirname(path)
-                ignore_paths[parent_dir].add(os.path.basename(path))
-
-        if os.path.isfile(project_path):
-            fake_path = Path(project_path).name
-            with open(project_path, 'rb') as f:
-                fmodules, try_ipts = file_import_modules(fake_path, f.read())
-                modules |= fmodules
-                try_imports |= try_ipts
-        else:
-            cur_dir = project_path  # os.getcwd()
-            for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
-                if dirpath in ignore_paths:
-                    dirnames[:] = [d for d in dirnames
-                                   if d not in ignore_paths[dirpath]]
-                py_files = list()
-                for fn in files:
-                    # C extension.
-                    if fn.endswith('.so'):
-                        local_mods.append(fn[:-3])
-                    # Normal Python file.
-                    if fn.endswith('.py'):
-                        local_mods.append(fn[:-3])
-                        py_files.append(fn)
-                if '__init__.py' in files:
-                    local_mods.append(os.path.basename(dirpath))
-                for file in py_files:
-                    fpath = os.path.join(dirpath, file)
-                    fake_path = fpath.split(cur_dir)[1][1:]
-                    with open(fpath, 'rb') as f:
-                        fmodules, try_ipts = file_import_modules(fake_path, f.read())
-                        modules |= fmodules
-                        try_imports |= try_ipts
-
-        return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods
-
    @staticmethod
    def add_trains_used_packages(modules):
        # hack: forcefully insert storage modules if we have them
@@ -159,7 +74,7 @@ class ScriptRequirements(object):
        return modules

    @staticmethod
-    def create_requirements_txt(reqs):
+    def create_requirements_txt(reqs, local_pks=None):
        # write requirements.txt
        try:
            conda_requirements = ''
@@ -188,6 +103,11 @@ class ScriptRequirements(object):
            # python version header
            requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'

+            if local_pks:
+                requirements_txt += '\n# Local modules found - skipping:\n'
+                for k, v in local_pks.sorted_items():
+                    requirements_txt += '# {0} == {1}\n'.format(k, v.version)
+
            # requirement summary
            requirements_txt += '\n'
            for k, v in reqs.sorted_items():
@@ -203,6 +123,13 @@ class ScriptRequirements(object):
            requirements_txt += '\n' + \
                                '# Detailed import analysis\n' \
                                '# **************************\n'

+            if local_pks:
+                for k, v in local_pks.sorted_items():
+                    requirements_txt += '\n'
+                    requirements_txt += '# IMPORT LOCAL PACKAGE {0}\n'.format(k)
+                    requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
+
            for k, v in reqs.sorted_items():
                requirements_txt += '\n'
                if k == '-e':
@@ -262,9 +189,9 @@ class _JupyterObserver(object):
        # load pigar
        # noinspection PyBroadException
        try:
-            from pigar.reqs import get_installed_pkgs_detail, file_import_modules
-            from pigar.modules import ReqsModules
-            from pigar.log import logger
+            from ....utilities.pigar.reqs import get_installed_pkgs_detail, file_import_modules
+            from ....utilities.pigar.modules import ReqsModules
+            from ....utilities.pigar.log import logger
            logger.setLevel(logging.WARNING)
        except Exception:
            file_import_modules = None
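The BUGFIX comment in the removed get_installed_pkgs_detail (the method now lives in the vendored trains/utilities/pigar/reqs.py below) is about merge order: packages found on earlier sys.path entries must win over later ones, mirroring how Python itself resolves imports. A minimal sketch of that merge, with made-up package tuples, not part of the commit:

    # "First match wins" merging: later sys.path entries never override
    # packages already found on earlier entries.
    mapping = {}
    for new_mapping in ({'numpy': ('numpy', '1.10.0')},                             # earlier sys.path entry
                        {'numpy': ('numpy', '1.16.0'), 'six': ('six', '1.12.0')}):  # later entry
        new_mapping.update(mapping)  # previously found packages keep precedence
        mapping = new_mapping
    assert mapping == {'numpy': ('numpy', '1.10.0'), 'six': ('six', '1.12.0')}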
trains/utilities/pigar/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
trains/utilities/pigar/__main__.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import os
import codecs

from .reqs import project_import_modules, is_std_or_local_lib
from .utils import lines_diff
from .log import logger
from .modules import ReqsModules


class GenerateReqs(object):

    def __init__(self, save_path, project_path, ignores,
                 installed_pkgs, comparison_operator='=='):
        self._save_path = save_path
        self._project_path = project_path
        self._ignores = ignores
        self._installed_pkgs = installed_pkgs
        self._maybe_local_mods = set()
        self._local_mods = dict()
        self._comparison_operator = comparison_operator

    def extract_reqs(self, module_callback=None):
        """Extract requirements from project."""

        reqs = ReqsModules()
        guess = ReqsModules()
        local = ReqsModules()
        modules, try_imports, local_mods = project_import_modules(
            self._project_path, self._ignores)
        if module_callback:
            modules = module_callback(modules)
        app_name = os.path.basename(self._project_path)
        if app_name in local_mods:
            local_mods.remove(app_name)

        # Filtering modules
        candidates = self._filter_modules(modules, local_mods)

        logger.info('Check module in local environment.')
        for name in candidates:
            logger.info('Checking module: %s', name)
            if name in self._installed_pkgs:
                pkg_name, version = self._installed_pkgs[name]
                reqs.add(pkg_name, version, modules[name])
            else:
                guess.add(name, 0, modules[name])

        # add local modules, so we know what is used but not installed.
        for name in self._local_mods:
            if name in modules:
                relpath = os.path.relpath(self._local_mods[name], self._project_path)
                if not relpath.startswith('.'):
                    relpath = '.' + os.path.sep + relpath
                local.add(name, relpath, modules[name])

        return reqs, try_imports, guess, local

    def _write_reqs(self, reqs):
        print('Writing requirements to "{0}"'.format(
            self._save_path))
        with open(self._save_path, 'w+') as f:
            f.write('# Requirements automatically generated by pigar.\n'
                    '# https://github.com/damnever/pigar\n')
            for k, v in reqs.sorted_items():
                f.write('\n')
                f.write(''.join(['# {0}\n'.format(c)
                                 for c in v.comments.sorted_items()]))
                if k == '-e':
                    f.write('{0} {1}\n'.format(k, v.version))
                elif v:
                    f.write('{0} {1} {2}\n'.format(
                        k, self._comparison_operator, v.version))
                else:
                    f.write('{0}\n'.format(k))

    def _best_matchs(self, name, pkgs):
        # If the imported name equals the package name.
        if name in pkgs:
            return [pkgs[pkgs.index(name)]]
        # If not, return all possible packages.
        return pkgs

    def _filter_modules(self, modules, local_mods):
        candidates = set()

        logger.info('Filtering modules ...')
        for module in modules:
            logger.info('Checking module: %s', module)
            if not module or module.startswith('.'):
                continue
            if module in local_mods:
                self._maybe_local_mods.add(module)
            module_std_local = is_std_or_local_lib(module)
            if module_std_local is True:
                continue
            if isinstance(module_std_local, str):
                self._local_mods[module] = module_std_local
                continue
            candidates.add(module)

        return candidates

    def _invalid_reqs(self, reqs):
        for name, detail in reqs.sorted_items():
            print(
                '  {0} referenced from:\n    {1}'.format(
                    name,
                    '\n    '.join(detail.comments.sorted_items())
                )
            )

    def _save_old_reqs(self):
        if os.path.isfile(self._save_path):
            with codecs.open(self._save_path, 'rb', 'utf-8') as f:
                self._old_reqs = f.readlines()

    def _reqs_diff(self):
        if not hasattr(self, '_old_reqs'):
            return
        with codecs.open(self._save_path, 'rb', 'utf-8') as f:
            new_reqs = f.readlines()
        is_diff, diffs = lines_diff(self._old_reqs, new_reqs)
        msg = 'Requirements file has been overwritten, '
        if is_diff:
            msg += 'there is the difference:'
            print('{0}\n{1}'.format(msg, ''.join(diffs)), end='')
        else:
            msg += 'no difference.'
            print(msg)
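Roughly how the new get_requirements() drives GenerateReqs end to end, as a sketch only: the project path is a placeholder and the absolute imports assume the trains package is importable.

    from trains.utilities.pigar.reqs import get_installed_pkgs_detail
    from trains.utilities.pigar.__main__ import GenerateReqs

    gr = GenerateReqs(save_path='', project_path='/path/to/project',
                      ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'],
                      installed_pkgs=get_installed_pkgs_detail())
    # reqs: installed packages that are imported; guess: imports with no
    # installed match; local_pks: imports resolved inside the project itself
    reqs, try_imports, guess, local_pks = gr.extract_reqs()
    for name, detail in reqs.sorted_items():
        print(name, detail.version)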
trains/utilities/pigar/extractor/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import


from ..utils import PY32


if PY32:
    from .thread_extractor import ThreadExtractor as Extractor
else:
    from .gevent_extractor import GeventExtractor as Extractor
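Whichever backend the PY32 switch selects, it is driven the same way through BaseExtractor.run(). A hypothetical job, assuming trains is importable (and gevent plus greenlet installed when the gevent backend is chosen):

    from trains.utilities.pigar.extractor import Extractor

    def fetch(name):
        # placeholder job; a real caller would fetch package metadata here
        print('processing', name)

    Extractor(['requests', 'numpy'], max_workers=4).run(fetch)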
trains/utilities/pigar/extractor/extractor.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import multiprocessing


class BaseExtractor(object):

    def __init__(self, names, max_workers=None):
        self._names = names
        self._max_workers = max_workers or (multiprocessing.cpu_count() * 4)

    def run(self, job):
        try:
            self.extract(job)
            self.wait_complete()
        except KeyboardInterrupt:
            print('** Shutting down ...')
            self.shutdown()
        else:
            print('^.^ Extracting all packages done!')
        finally:
            self.final()

    def extract(self, job):
        raise NotImplementedError

    def wait_complete(self):
        raise NotImplementedError

    def shutdown(self):
        raise NotImplementedError

    def final(self):
        pass
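run() is a template method: a subclass only has to supply extract(), wait_complete() and shutdown(), which is exactly what the gevent and thread variants below do. An illustrative minimal subclass, not part of the commit:

    from trains.utilities.pigar.extractor.extractor import BaseExtractor

    class SerialExtractor(BaseExtractor):
        def extract(self, job):
            for name in self._names:
                job(name)  # run synchronously, one name at a time

        def wait_complete(self):
            pass  # nothing pending, extract() already ran every job

        def shutdown(self):
            pass  # nothing to cancel

    SerialExtractor(['foo', 'bar']).run(print)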
trains/utilities/pigar/extractor/gevent_extractor.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import sys

import greenlet
from gevent.pool import Pool

from .extractor import BaseExtractor
from ..log import logger


class GeventExtractor(BaseExtractor):

    def __init__(self, names, max_workers=222):
        super(GeventExtractor, self).__init__(names, max_workers)
        self._pool = Pool(self._max_workers)
        self._exited_greenlets = 0

    def extract(self, job):
        job = self._job_wrapper(job)
        for name in self._names:
            if self._pool.full():
                self._pool.wait_available()
            self._pool.spawn(job, name)

    def _job_wrapper(self, job):
        def _job(name):
            result = None
            try:
                result = job(name)
            except greenlet.GreenletExit:
                self._exited_greenlets += 1
            except Exception:
                e = sys.exc_info()[1]
                logger.error('Extracting "{0}", got: {1}'.format(name, e))
            return result
        return _job

    def wait_complete(self):
        self._pool.join()

    def shutdown(self):
        self._pool.kill(block=True)

    def final(self):
        count = self._exited_greenlets
        if count != 0:
            print('** {0} running jobs exited.'.format(count))
trains/utilities/pigar/extractor/thread_extractor.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import concurrent.futures

from .extractor import BaseExtractor

from ..log import logger


class ThreadExtractor(BaseExtractor):
    """Extractor that uses a thread pool to execute tasks.

    Can be used to extract /simple/<pkg_name> or /pypi/<pkg_name>/json.

    FIXME: can not deliver SIGINT to threads in Python 2.
    """

    def __init__(self, names, max_workers=None):
        super(ThreadExtractor, self).__init__(names, max_workers)
        self._futures = dict()

    def extract(self, job):
        """Extract url by package name."""
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self._max_workers) as executor:
            for name in self._names:
                self._futures[executor.submit(job, name)] = name

    def wait_complete(self):
        """Wait for all futures to complete."""
        for future in concurrent.futures.as_completed(self._futures.keys()):
            try:
                error = future.exception()
            except concurrent.futures.CancelledError:
                break
            name = self._futures[future]
            if error is not None:
                err_msg = 'Extracting "{0}", got: {1}'.format(name, error)
                logger.error(err_msg)

    def shutdown(self):
        for future in self._futures:
            future.cancel()
trains/utilities/pigar/log.py (new file, 9 lines)
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import logging.handlers


logger = logging.getLogger('pigar')
logger.setLevel(logging.WARNING)
trains/utilities/pigar/modules.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import collections


# FIXME: Just a workaround, not a radical cure..
_special_cases = {
    "dogpile.cache": "dogpile.cache",
    "dogpile.core": "dogpile.core",
    "ruamel.yaml": "ruamel.yaml",
    "ruamel.ordereddict": "ruamel.ordereddict",
}


class Modules(dict):
    """Modules object will be used to store modules information."""

    def __init__(self):
        super(Modules, self).__init__()


class ImportedModules(Modules):

    def __init__(self):
        super(ImportedModules, self).__init__()

    def add(self, name, file, lineno):
        if name is None:
            return

        names = list()
        special_name = '.'.join(name.split('.')[:2])
        # Flask extension.
        if name.startswith('flask.ext.'):
            names.append('flask')
            names.append('flask_' + name.split('.')[2])
        # Special cases..
        elif special_name in _special_cases:
            names.append(_special_cases[special_name])
        # Other.
        elif '.' in name and not name.startswith('.'):
            names.append(name.split('.')[0])
        else:
            names.append(name)

        for nm in names:
            if nm not in self:
                self[nm] = _Locations()
            self[nm].add(file, lineno)

    def __or__(self, obj):
        for name, locations in obj.items():
            for file, linenos in locations.items():
                for lineno in linenos:
                    self.add(name, file, lineno)
        return self


class ReqsModules(Modules):

    _Detail = collections.namedtuple('Detail', ['version', 'comments'])

    def __init__(self):
        super(ReqsModules, self).__init__()
        self._sorted = None

    def add(self, package, version, locations):
        if package in self:
            self[package].comments.extend(locations)
        else:
            self[package] = self._Detail(version, locations)

    def sorted_items(self):
        if self._sorted is None:
            self._sorted = sorted(self.items())
        return self._sorted

    def remove(self, *names):
        for name in names:
            if name in self:
                self.pop(name)
        self._sorted = None


class _Locations(dict):
    """_Locations stores code locations (file, line numbers)."""

    def __init__(self):
        super(_Locations, self).__init__()
        self._sorted = None

    def add(self, file, lineno):
        if file in self and lineno not in self[file]:
            self[file].append(lineno)
        else:
            self[file] = [lineno]

    def extend(self, obj):
        for file, linenos in obj.items():
            for lineno in linenos:
                self.add(file, lineno)

    def sorted_items(self):
        if self._sorted is None:
            self._sorted = [
                '{0}: {1}'.format(f, ','.join([str(n) for n in sorted(ls)]))
                for f, ls in sorted(self.items())
            ]
        return self._sorted
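A quick illustration, not from the commit, of how ImportedModules normalizes dotted imports and how _Locations accumulates (file, lineno) pairs; the import path assumes trains is installed:

    from trains.utilities.pigar.modules import ImportedModules

    m = ImportedModules()
    m.add('os.path', 'a.py', 1)          # dotted names are stored under 'os'
    m.add('flask.ext.login', 'a.py', 2)  # expands to 'flask' and 'flask_login'
    m.add('os', 'b.py', 7)               # merged into the existing 'os' entry
    print(sorted(m))                     # ['flask', 'flask_login', 'os']
    print(m['os'].sorted_items())        # ['a.py: 1', 'b.py: 7']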
trains/utilities/pigar/reqs.py (new file, 398 lines)
@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import os
import sys
import fnmatch
import importlib
import imp
import ast
import doctest
import collections
import functools
from pathlib2 import Path
try:
    from types import FileType  # py2
except ImportError:
    from io import IOBase as FileType  # py3

from .log import logger
from .utils import parse_git_config
from .modules import ImportedModules


def project_import_modules(project_path, ignores):
    """
    copied from pigar reqs.project_import_modules;
    patched because os.getcwd() was incorrectly used
    """
    modules = ImportedModules()
    try_imports = set()
    local_mods = list()
    ignore_paths = collections.defaultdict(set)
    if not ignores:
        ignore_paths[project_path].add('.git')
    else:
        for path in ignores:
            parent_dir = os.path.dirname(path)
            ignore_paths[parent_dir].add(os.path.basename(path))

    if os.path.isfile(project_path):
        fake_path = Path(project_path).name
        with open(project_path, 'rb') as f:
            fmodules, try_ipts = file_import_modules(fake_path, f.read())
            modules |= fmodules
            try_imports |= try_ipts
    else:
        cur_dir = project_path  # os.getcwd()
        for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
            if dirpath in ignore_paths:
                dirnames[:] = [d for d in dirnames
                               if d not in ignore_paths[dirpath]]
            py_files = list()
            for fn in files:
                # C extension.
                if fn.endswith('.so'):
                    local_mods.append(fn[:-3])
                # Normal Python file.
                if fn.endswith('.py'):
                    local_mods.append(fn[:-3])
                    py_files.append(fn)
            if '__init__.py' in files:
                local_mods.append(os.path.basename(dirpath))
            for file in py_files:
                fpath = os.path.join(dirpath, file)
                fake_path = fpath.split(cur_dir)[1][1:]
                with open(fpath, 'rb') as f:
                    fmodules, try_ipts = file_import_modules(fake_path, f.read())
                    modules |= fmodules
                    try_imports |= try_ipts

    return modules, try_imports, local_mods


def file_import_modules(fpath, fdata):
    """Get all the modules imported by a single file."""
    modules = ImportedModules()
    str_codes = collections.deque([(fdata, 1)])
    try_imports = set()

    while str_codes:
        str_code, lineno = str_codes.popleft()
        ic = ImportChecker(fpath, lineno)
        try:
            parsed = ast.parse(str_code)
            ic.visit(parsed)
        # Ignore SyntaxError in Python code.
        except SyntaxError:
            pass
        modules |= ic.modules
        str_codes.extend(ic.str_codes)
        try_imports |= ic.try_imports
        del ic

    return modules, try_imports


class ImportChecker(object):

    def __init__(self, fpath, lineno):
        self._fpath = fpath
        self._lineno = lineno - 1
        self._modules = ImportedModules()
        self._str_codes = collections.deque()
        self._try_imports = set()

    def visit_Import(self, node, try_=False):
        """As we know: `import a [as b]`."""
        lineno = node.lineno + self._lineno
        for alias in node.names:
            self._modules.add(alias.name, self._fpath, lineno)
            if try_:
                self._try_imports.add(alias.name)

    def visit_ImportFrom(self, node, try_=False):
        """
        As we know: `from a import b [as c]`. If node.level is not 0,
        the import statement looks like `from .a import b`.
        """
        mod_name = node.module
        level = node.level
        if mod_name is None:
            level -= 1
            mod_name = ''
        for alias in node.names:
            name = level*'.' + mod_name + '.' + alias.name
            self._modules.add(name, self._fpath, node.lineno + self._lineno)
            if try_:
                self._try_imports.add(name)

    def visit_TryExcept(self, node):
        """
        Modules imported inside `try/except` blocks that cannot be found
        may come from a different Python version.
        """
        for ipt in node.body:
            if ipt.__class__.__name__.startswith('Import'):
                method = 'visit_' + ipt.__class__.__name__
                getattr(self, method)(ipt, True)
        for handler in node.handlers:
            for ipt in handler.body:
                if ipt.__class__.__name__.startswith('Import'):
                    method = 'visit_' + ipt.__class__.__name__
                    getattr(self, method)(ipt, True)

    # For Python 3.3+
    visit_Try = visit_TryExcept

    def visit_Exec(self, node):
        """
        Check `expression` of `exec(expression[, globals[, locals]])`.
        **Only available in Python 2.**
        """
        if hasattr(node.body, 's'):
            self._str_codes.append((node.body.s, node.lineno + self._lineno))
        # PR#13: https://github.com/damnever/pigar/pull/13
        # Sometimes the exec statement may be called with a tuple in Py2.7.6
        elif hasattr(node.body, 'elts') and len(node.body.elts) >= 1:
            self._str_codes.append(
                (node.body.elts[0].s, node.lineno + self._lineno))

    def visit_Expr(self, node):
        """
        Check `expression` of `eval(expression[, globals[, locals]])`.
        Check `expression` of `exec(expression[, globals[, locals]])`
        in python 3.
        Check `name` of `__import__(name[, globals[, locals[,
        fromlist[, level]]]])`.
        Check `name` or `package` of `importlib.import_module(name,
        package=None)`.
        """
        # Built-in functions
        value = node.value
        if isinstance(value, ast.Call):
            if hasattr(value.func, 'id'):
                if (value.func.id == 'eval' and
                        hasattr(node.value.args[0], 's')):
                    self._str_codes.append(
                        (node.value.args[0].s, node.lineno + self._lineno))
                # **`exec` function in Python 3.**
                elif (value.func.id == 'exec' and
                        hasattr(node.value.args[0], 's')):
                    self._str_codes.append(
                        (node.value.args[0].s, node.lineno + self._lineno))
                # `__import__` function.
                elif (value.func.id == '__import__' and
                        len(node.value.args) > 0 and
                        hasattr(node.value.args[0], 's')):
                    self._modules.add(node.value.args[0].s, self._fpath,
                                      node.lineno + self._lineno)
            # `import_module` function.
            elif getattr(value.func, 'attr', '') == 'import_module':
                module = getattr(value.func, 'value', None)
                if (module is not None and
                        getattr(module, 'id', '') == 'importlib'):
                    args = node.value.args
                    arg_len = len(args)
                    if arg_len > 0 and hasattr(args[0], 's'):
                        name = args[0].s
                        if not name.startswith('.'):
                            self._modules.add(name, self._fpath,
                                              node.lineno + self._lineno)
                        elif arg_len == 2 and hasattr(args[1], 's'):
                            self._modules.add(args[1].s, self._fpath,
                                              node.lineno + self._lineno)

    def visit_FunctionDef(self, node):
        """
        Check the docstring of a function, in case it is used for doctest.
        """
        docstring = self._parse_docstring(node)
        if docstring:
            self._str_codes.append((docstring, node.lineno + self._lineno + 2))

    def visit_ClassDef(self, node):
        """
        Check the docstring of a class, in case it is used for doctest.
        """
        docstring = self._parse_docstring(node)
        if docstring:
            self._str_codes.append((docstring, node.lineno + self._lineno + 2))

    def visit(self, node):
        """Visit a node; ast.walk already yields every descendant,
        so no explicit recursion is needed."""
        for node in ast.walk(node):
            method = 'visit_' + node.__class__.__name__
            getattr(self, method, lambda x: x)(node)

    @staticmethod
    def _parse_docstring(node):
        """Extract code from docstring."""
        docstring = ast.get_docstring(node)
        if docstring:
            parser = doctest.DocTestParser()
            try:
                dt = parser.get_doctest(docstring, {}, None, None, None)
            except ValueError:
                # >>> 'abc'
                pass
            else:
                examples = dt.examples
                return '\n'.join([example.source for example in examples])
        return None

    @property
    def modules(self):
        return self._modules

    @property
    def str_codes(self):
        return self._str_codes

    @property
    def try_imports(self):
        return set((name.split('.')[0] if name and '.' in name else name)
                   for name in self._try_imports)


def _checked_cache(func):
    checked = dict()

    @functools.wraps(func)
    def _wrapper(name):
        if name not in checked:
            checked[name] = func(name)
        return checked[name]

    return _wrapper


@_checked_cache
def is_std_or_local_lib(name):
    """Check whether the module is a stdlib module.

    True if std lib
    False if installed package
    str (the module path) if local library
    """
    exist = True
    module_info = ('', '', '')
    try:
        module_info = imp.find_module(name)
    except ImportError:
        try:
            # __import__(name)
            importlib.import_module(name)
            module_info = imp.find_module(name)
            sys.modules.pop(name)
        except ImportError:
            exist = False
    # Testcase: ResourceWarning
    if isinstance(module_info[0], FileType):
        module_info[0].close()
    mpath = module_info[1]
    if exist and mpath is not None:
        if ('site-packages' in mpath or
                'dist-packages' in mpath or
                ('bin/' in mpath and mpath.endswith('.py'))):
            exist = False
        elif ((sys.prefix not in mpath) and
                (sys.base_exec_prefix not in mpath) and
                (sys.base_prefix not in mpath)):
            exist = mpath

    return exist


def get_installed_pkgs_detail():
    """
    HACK: bugfix of the original pigar get_installed_pkgs_detail

    Get the mapping from import top-level names
    to installed package name and version.
    """
    mapping = dict()

    for path in sys.path:
        if os.path.isdir(path) and path.rstrip('/').endswith(
                ('site-packages', 'dist-packages')):
            new_mapping = _search_path(path)
            # BUGFIX:
            # override with the previous mapping: just like Python resolves
            # imports, the first match is the one used,
            # unlike the original implementation, where the last one is used.
            new_mapping.update(mapping)
            mapping = new_mapping

            # HACK: prefer tensorflow_gpu over tensorflow
            if 'tensorflow_gpu' in new_mapping:
                new_mapping['tensorflow'] = new_mapping['tensorflow_gpu']

    return mapping


def _search_path(path):
    mapping = dict()

    for file in os.listdir(path):
        # Install from PYPI.
        if fnmatch.fnmatch(file, '*-info'):
            top_level = os.path.join(path, file, 'top_level.txt')
            pkg_name, version = file.split('-')[:2]
            if version.endswith('dist'):
                version = version.rsplit('.', 1)[0]
            # Issue for ubuntu: sudo pip install xxx
            elif version.endswith('egg'):
                version = version.rsplit('.', 1)[0]
            mapping[pkg_name] = (pkg_name, version)
            if not os.path.isfile(top_level):
                continue
            with open(top_level, 'r') as f:
                for line in f:
                    mapping[line.strip()] = (pkg_name, version)

        # Install from local and available in GitHub.
        elif fnmatch.fnmatch(file, '*-link'):
            link = os.path.join(path, file)
            if not os.path.isfile(link):
                continue
            # Link path.
            with open(link, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line != '.':
                        dev_dir = line
            if not dev_dir:
                continue
            # Egg info path.
            info_dir = [_file for _file in os.listdir(dev_dir)
                        if _file.endswith('egg-info')]
            if not info_dir:
                continue
            info_dir = info_dir[0]
            top_level = os.path.join(dev_dir, info_dir, 'top_level.txt')
            # Check whether it can be imported.
            if not os.path.isfile(top_level):
                continue

            # Check .git dir.
            git_path = os.path.join(dev_dir, '.git')
            if os.path.isdir(git_path):
                config = parse_git_config(git_path)
                url = config.get('remote "origin"', {}).get('url')
                if not url:
                    continue
                branch = 'branch "master"'
                if branch not in config:
                    for section in config:
                        if 'branch' in section:
                            branch = section
                            break
                if not branch:
                    continue
                branch = branch.split()[1][1:-1]

                pkg_name = info_dir.split('.egg')[0]
                git_url = 'git+{0}@{1}#egg={2}'.format(url, branch, pkg_name)
                with open(top_level, 'r') as f:
                    for line in f:
                        mapping[line.strip()] = ('-e', git_url)
    return mapping
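file_import_modules() can be exercised on its own to see the try/except tracking in action; a sketch, assuming the vendored module is importable (the expected output follows from the implementation above):

    from trains.utilities.pigar.reqs import file_import_modules

    source = (b"import numpy\n"
              b"try:\n"
              b"    import simplejson as json\n"
              b"except ImportError:\n"
              b"    import json\n")
    modules, try_imports = file_import_modules('example.py', source)
    print(sorted(modules))      # ['json', 'numpy', 'simplejson']
    print(sorted(try_imports))  # ['json', 'simplejson']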
trains/utilities/pigar/unpack.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import tarfile
import zipfile
import re
import string
import io


class Archive(object):
    """Archive provides a consistent interface for unpacking
    compressed files.
    """

    def __init__(self, filename, fileobj):
        self._filename = filename
        self._fileobj = fileobj
        self._file = None
        self._names = None
        self._read = None

    @property
    def filename(self):
        return self._filename

    @property
    def names(self):
        """The name list is computed lazily, only when first needed."""
        if self._file is None:
            self._prepare()
        if not hasattr(self, '_namelist'):
            self._namelist = self._names()
        return self._namelist

    def close(self):
        """Close file object."""
        if self._file is not None:
            self._file.close()
        if hasattr(self, '_namelist'):
            del self._namelist
        self._filename = self._fileobj = None
        self._file = self._names = self._read = None

    def read(self, filename):
        """Read one file from archive."""
        if self._file is None:
            self._prepare()
        return self._read(filename)

    def unpack(self, to_path):
        """Unpack compressed files to path."""
        if self._file is None:
            self._prepare()
        self._safe_extractall(to_path)

    def _prepare(self):
        if self._filename.endswith(('.tar.gz', '.tar.bz2', '.tar.xz')):
            self._prepare_tarball()
        # An .egg file is actually just a .zip file
        # with a different extension, .whl too.
        elif self._filename.endswith(('.zip', '.egg', '.whl')):
            self._prepare_zip()
        else:
            raise ValueError("unreadable: {0}".format(self._filename))

    def _safe_extractall(self, to_path='.'):
        unsafe = []
        for name in self.names:
            if not self.is_safe(name):
                unsafe.append(name)
        if unsafe:
            raise ValueError("unsafe to unpack: {}".format(unsafe))
        self._file.extractall(to_path)

    def _prepare_zip(self):
        self._file = zipfile.ZipFile(self._fileobj)
        self._names = self._file.namelist
        self._read = self._file.read

    def _prepare_tarball(self):
        # tarfile has no read method
        def _read(filename):
            f = self._file.extractfile(filename)
            return f.read()

        self._file = tarfile.open(mode='r:*', fileobj=self._fileobj)
        self._names = self._file.getnames
        self._read = _read

    def is_safe(self, filename):
        return not (filename.startswith(("/", "\\")) or
                    (len(filename) > 1 and filename[1] == ":" and
                     filename[0] in string.ascii_letters) or
                    re.search(r"[.][.][/\\]", filename))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


def top_level(url, data):
    """Read top level names from compressed file."""
    sb = io.BytesIO(data)
    txt = None
    with Archive(url, sb) as archive:
        file = None
        for name in archive.names:
            if name.lower().endswith('top_level.txt'):
                file = name
                break
        if file:
            txt = archive.read(file).decode('utf-8')
    sb.close()
    return [name.replace('/', '.') for name in txt.splitlines()] if txt else []
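Since top_level() only needs raw archive bytes, it can be demonstrated with an in-memory zip dressed up as a wheel; a sketch with a made-up file name:

    import io
    import zipfile

    from trains.utilities.pigar.unpack import top_level

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w') as zf:
        zf.writestr('demo-1.0.dist-info/top_level.txt', 'demo\n')
    print(top_level('demo-1.0-py3-none-any.whl', buf.getvalue()))  # ['demo']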
trains/utilities/pigar/utils.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import os
import re
import sys
import difflib


PY32 = sys.version_info[:2] == (3, 2)

if sys.version_info[0] == 3:
    binary_type = bytes
else:
    binary_type = str


class Dict(dict):
    """Dict that exposes its keys as attributes."""

    def __init__(self, *args, **kwargs):
        super(Dict, self).__init__(*args, **kwargs)

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError('"{0}"'.format(name))

    def __setattr__(self, name, value):
        self[name] = value


def parse_reqs(fpath):
    """Parse requirements file."""
    pkg_v_re = re.compile(r'^(?P<pkg>[^><==]+)[><==]{,2}(?P<version>.*)$')
    reqs = dict()
    with open(fpath, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            m = pkg_v_re.match(line.strip())
            if m:
                d = m.groupdict()
                reqs[d['pkg'].strip()] = d['version'].strip()
    return reqs


def cmp_to_key(cmp_func):
    """Convert a cmp= function into a key= function."""
    class K(object):
        def __init__(self, obj, *args):
            self.obj = obj

        def __lt__(self, other):
            return cmp_func(self.obj, other.obj) < 0

        def __gt__(self, other):
            return cmp_func(self.obj, other.obj) > 0

        def __eq__(self, other):
            return cmp_func(self.obj, other.obj) == 0

    return K


def compare_version(version1, version2):
    """Compare version numbers, such as 1.1.1 and 1.1b2.0."""
    v1, v2 = list(), list()

    for item in version1.split('.'):
        if item.isdigit():
            v1.append(int(item))
        else:
            v1.extend([i for i in _group_alnum(item)])
    for item in version2.split('.'):
        if item.isdigit():
            v2.append(int(item))
        else:
            v2.extend([i for i in _group_alnum(item)])

    while v1 and v2:
        item1, item2 = v1.pop(0), v2.pop(0)
        if item1 > item2:
            return 1
        elif item1 < item2:
            return -1

    if v1:
        return 1
    elif v2:
        return -1
    return 0


def _group_alnum(s):
    tmp = list()
    flag = 1 if s[0].isdigit() else 0
    for c in s:
        if c.isdigit():
            if flag == 0:
                yield ''.join(tmp)
                tmp = list()
                flag = 1
            tmp.append(c)
        elif c.isalpha():
            if flag == 1:
                yield int(''.join(tmp))
                tmp = list()
                flag = 0
            tmp.append(c)
    last = ''.join(tmp)
    yield (int(last) if flag else last)


def parse_git_config(path):
    """Parse git config file."""
    config = dict()
    section = None

    with open(os.path.join(path, 'config'), 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('['):
                section = line[1:-1].strip()
                config[section] = dict()
            elif section:
                key, value = line.replace(' ', '').split('=')
                config[section][key] = value
    return config


def lines_diff(lines1, lines2):
    """Show the difference between two sets of lines."""
    is_diff = False
    diffs = list()

    for line in difflib.ndiff(lines1, lines2):
        if not is_diff and line[0] in ('+', '-'):
            is_diff = True
        diffs.append(line)

    return is_diff, diffs
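A few illustrative calls against the helpers above (a sketch, not test code from the commit; the expected values follow from the implementations):

    from trains.utilities.pigar.utils import compare_version, lines_diff

    print(compare_version('1.2.0', '1.10'))  # -1: components compare as ints, 2 < 10
    print(compare_version('2.0.1', '2.0'))   # 1: the longer version wins a tie
    is_diff, diffs = lines_diff(['a\n', 'b\n'], ['a\n', 'c\n'])
    print(is_diff)         # True
    print(''.join(diffs))  # ndiff output: '  a', '- b', '+ c'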