mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Initial release
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
51
tests/conftest.py
Normal file
51
tests/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from argparse import Namespace
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pathlib2 import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def run_trains_agent(script_runner):
|
||||
""" Execute trains_agent agent app in subprocess and return stdout as a string.
|
||||
Args:
|
||||
script_runner (object): a pytest plugin for testing python scripts
|
||||
installed via console_scripts entry point of setup.py.
|
||||
It can run the scripts under test in a separate process or using the interpreter that's running
|
||||
the test suite. The former mode ensures that the script will run in an environment that
|
||||
is identical to normal execution whereas the latter one allows much quicker test runs during development
|
||||
while simulating the real runs as muh as possible.
|
||||
For more details: https://pypi.python.org/pypi/pytest-console-scripts
|
||||
Returns:
|
||||
string: The return value. stdout output
|
||||
"""
|
||||
def _method(*args):
|
||||
trains_agent_file = str(PROJECT_ROOT / "trains_agent.sh")
|
||||
ret = script_runner.run(trains_agent_file, *args)
|
||||
return ret
|
||||
return _method
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def trains_agentyaml(tmpdir):
|
||||
@contextmanager
|
||||
def _method(template_file):
|
||||
file = tmpdir.join("trains_agent.yaml")
|
||||
with (PROJECT_ROOT / "tests/templates" / template_file).open() as f:
|
||||
code = yaml.load(f)
|
||||
yield Namespace(code=code, file=file.strpath)
|
||||
file.write(yaml.dump(code))
|
||||
return _method
|
||||
|
||||
|
||||
# class Test(object):
|
||||
# def yaml_file(self, tmpdir, template_file):
|
||||
# file = tmpdir.join("trains_agent.yaml")
|
||||
# with open(template_file) as f:
|
||||
# test_object = yaml.load(f)
|
||||
# self.let(test_object)
|
||||
# file.write(yaml.dump(test_object))
|
||||
# return file.strpath
|
||||
0
tests/package/__init__.py
Normal file
0
tests/package/__init__.py
Normal file
35
tests/package/ssh_conversion.py
Normal file
35
tests/package/ssh_conversion.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from trains_agent.helper.repo import VCS
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["url", "expected"],
|
||||
(
|
||||
("a", None),
|
||||
("foo://a/b", None),
|
||||
("foo://a/b/", None),
|
||||
("https://a/b/", None),
|
||||
("https://example.com/a/b", None),
|
||||
("https://example.com/a/b/", None),
|
||||
("ftp://example.com/a/b", None),
|
||||
("ftp://example.com/a/b/", None),
|
||||
("github.com:foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("git@github.com:foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("hg@bitbucket.org:foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://git@github.com/foo/bar.git", "https://github.com/foo/bar.git"),
|
||||
("ssh://user@github.com/foo/bar.git", "https://user@github.com/foo/bar.git"),
|
||||
("ssh://git:password@github.com/foo/bar.git", "https://git:password@github.com/foo/bar.git"),
|
||||
("ssh://user:password@github.com/foo/bar.git", "https://user:password@github.com/foo/bar.git"),
|
||||
("ssh://hg@bitbucket.org/foo/bar.git", "https://bitbucket.org/foo/bar.git"),
|
||||
("ssh://user@bitbucket.org/foo/bar.git", "https://user@bitbucket.org/foo/bar.git"),
|
||||
("ssh://hg:password@bitbucket.org/foo/bar.git", "https://hg:password@bitbucket.org/foo/bar.git"),
|
||||
("ssh://user:password@bitbucket.org/foo/bar.git", "https://user:password@bitbucket.org/foo/bar.git"),
|
||||
),
|
||||
)
|
||||
def test(url, expected):
|
||||
result = VCS.resolve_ssh_url(url)
|
||||
expected = expected or url
|
||||
assert result == expected
|
||||
76
tests/package/test_pip_download_cache.py
Normal file
76
tests/package/test_pip_download_cache.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
from furl import furl
|
||||
|
||||
from trains_agent.helper.package.translator import RequirementsTranslator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"line",
|
||||
(
|
||||
furl()
|
||||
.set(
|
||||
scheme=scheme,
|
||||
host=host,
|
||||
path=path,
|
||||
query=query,
|
||||
fragment=fragment,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
.url
|
||||
for scheme in ("http", "https", "ftp")
|
||||
for host in ("a", "example.com")
|
||||
for path in (None, "/", "a", "a/", "a/b", "a/b/", "a b", "a b ")
|
||||
for query in (None, "foo", "foo=3", "foo=3&bar")
|
||||
for fragment in (None, "foo")
|
||||
for port in (None, 1337)
|
||||
for username in (None, "", "user")
|
||||
for password in (None, "", "password")
|
||||
),
|
||||
)
|
||||
def test_supported(line):
|
||||
assert "://" in line
|
||||
assert RequirementsTranslator.is_supported_link(line)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"line",
|
||||
[
|
||||
"pytorch",
|
||||
"foo",
|
||||
"foo1",
|
||||
"bar",
|
||||
"bar1",
|
||||
"foo-bar",
|
||||
"foo-bar1",
|
||||
"foo-bar-1",
|
||||
"foo_bar",
|
||||
"foo_bar1",
|
||||
"foo_bar_1",
|
||||
" https://a",
|
||||
" https://a/b",
|
||||
" http://a",
|
||||
" http://a/b",
|
||||
" ftp://a/b",
|
||||
"file://a/b",
|
||||
"ssh://a/b",
|
||||
"foo://a/b",
|
||||
"git//a/b",
|
||||
"git+https://a/b",
|
||||
"https+git://a/b",
|
||||
"git+http://a/b",
|
||||
"http+git://a/b",
|
||||
"",
|
||||
" ",
|
||||
"-e ",
|
||||
"-e x",
|
||||
"-e http://a",
|
||||
"-e http://a/b",
|
||||
"-e https://a",
|
||||
"-e https://a/b",
|
||||
"-e file://a/b",
|
||||
],
|
||||
)
|
||||
def test_not_supported(line):
|
||||
assert not RequirementsTranslator.is_supported_link(line)
|
||||
33
tests/package/test_pytorch_map.py
Normal file
33
tests/package/test_pytorch_map.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import attr
|
||||
import pytest
|
||||
import requests
|
||||
from furl import furl
|
||||
|
||||
import six
|
||||
from trains_agent.helper.package.pytorch import PytorchRequirement
|
||||
|
||||
|
||||
@attr.s
|
||||
class PytorchURLWheel(object):
|
||||
os = attr.ib()
|
||||
cuda = attr.ib()
|
||||
python = attr.ib()
|
||||
pytorch = attr.ib()
|
||||
url = attr.ib()
|
||||
|
||||
|
||||
wheels = [
|
||||
PytorchURLWheel(os=os, cuda=cuda, python=python, pytorch=pytorch_version, url=url)
|
||||
for os, os_d in PytorchRequirement.MAP.items()
|
||||
for cuda, cuda_d in os_d.items()
|
||||
if isinstance(cuda_d, dict)
|
||||
for python, python_d in cuda_d.items()
|
||||
if isinstance(python_d, dict)
|
||||
for pytorch_version, url in python_d.items()
|
||||
if isinstance(url, six.string_types) and furl(url).scheme
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wheel', wheels, ids=[','.join(map(str, attr.astuple(wheel))) for wheel in wheels])
|
||||
def test_map(wheel):
|
||||
assert requests.head(wheel.url).ok
|
||||
75
tests/package/test_repo_url_auth.py
Normal file
75
tests/package/test_repo_url_auth.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
|
||||
from trains_agent.helper.repo import Git
|
||||
|
||||
NO_CHANGE = object()
|
||||
|
||||
|
||||
def param(url, expected, user=False, password=False):
|
||||
"""
|
||||
Helper function for creating parametrization arguments.
|
||||
:param url: input url
|
||||
:param expected: expected output URL or NO_CHANGE if the same as input URL
|
||||
:param user: Add `agent.git_user=USER` to config
|
||||
:param password: Add `agent.git_password=PASSWORD` to config
|
||||
"""
|
||||
expected_repr = "NO_CHANGE" if expected is NO_CHANGE else None
|
||||
user = "USER" if user else None
|
||||
password = "PASSWORD" if password else None
|
||||
return pytest.param(
|
||||
url,
|
||||
expected,
|
||||
user,
|
||||
password,
|
||||
id="-".join(filter(None, (url, user, password, expected_repr))),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,expected,user,password",
|
||||
[
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("https://bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"https://bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("https://user@bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"https://user@bitbucket.org/company/repo",
|
||||
"https://USER:PASSWORD@bitbucket.org/company/repo",
|
||||
user=True,
|
||||
password=True,
|
||||
),
|
||||
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("https://user:password@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param(
|
||||
"https://user:password@bitbucket.org/company/repo", NO_CHANGE, password=True
|
||||
),
|
||||
param(
|
||||
"https://user:password@bitbucket.org/company/repo",
|
||||
NO_CHANGE,
|
||||
user=True,
|
||||
password=True,
|
||||
),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True),
|
||||
param("ssh://git@bitbucket.org/company/repo", NO_CHANGE, password=True),
|
||||
param(
|
||||
"ssh://git@bitbucket.org/company/repo", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE, user=True),
|
||||
param("git@bitbucket.org:company/repo.git", NO_CHANGE, password=True),
|
||||
param(
|
||||
"git@bitbucket.org:company/repo.git", NO_CHANGE, user=True, password=True
|
||||
),
|
||||
],
|
||||
)
|
||||
def test(url, user, password, expected):
|
||||
config = {"agent": {"git_user": user, "git_pass": password}}
|
||||
result = Git.add_auth(config, url)
|
||||
expected = result if expected is NO_CHANGE else expected
|
||||
assert result == expected
|
||||
254
tests/package/test_task_script.py
Normal file
254
tests/package/test_task_script.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Test handling of jupyter notebook tasks.
|
||||
Logging is enabled in `trains_agent/tests/pytest.ini`. Search for `pytest live logging` for more info.
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
import select
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, ContextManager, Sequence, IO, Text
|
||||
from uuid import uuid4
|
||||
|
||||
from trains_agent.backend_api.services.tasks import Script
|
||||
from trains_agent.backend_api.session.client import APIClient
|
||||
from pathlib2 import Path
|
||||
from pytest import fixture
|
||||
|
||||
from trains_agent.helper.process import Argv
|
||||
|
||||
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_TASK_ARGS = {"type": "testing", "name": "test", "input": {"view": {}}}
|
||||
HERE = Path(__file__).resolve().parent
|
||||
SHORT_TIMEOUT = 30
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def client():
|
||||
return APIClient(api_version="2.2")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def create_task(client, **params):
|
||||
"""
|
||||
Create task in backend
|
||||
"""
|
||||
log.info("creating new task")
|
||||
task = client.tasks.create(**params)
|
||||
try:
|
||||
yield task
|
||||
finally:
|
||||
log.info("deleting task, id=%s", task.id)
|
||||
task.delete(force=True)
|
||||
|
||||
|
||||
def select_read(file_obj, timeout):
|
||||
return select.select([file_obj], [], [], timeout)[0]
|
||||
|
||||
|
||||
def run_task(task):
|
||||
return Argv("trains_agent", "--debug", "worker", "execute", "--id", task.id)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def iterate_output(timeout, command):
|
||||
# type: (int, Argv) -> ContextManager[Iterator[Text]]
|
||||
"""
|
||||
Run `command` in a subprocess and return a contextmanager of an iterator
|
||||
over its output's lines. If `timeout` seconds have passed, iterator ends and
|
||||
the process is killed.
|
||||
:param timeout: maximum amount of time to wait for command to end
|
||||
:param command: command to run
|
||||
"""
|
||||
log.info("running: %s", command)
|
||||
process = command.call_subprocess(
|
||||
subprocess.Popen, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||
)
|
||||
try:
|
||||
yield _iterate_output(timeout, process)
|
||||
finally:
|
||||
status = process.poll()
|
||||
if status is not None:
|
||||
log.info("command %s terminated, status: %s", command, status)
|
||||
else:
|
||||
log.info("killing command %s", command)
|
||||
process.kill()
|
||||
|
||||
|
||||
def _iterate_output(timeout, process):
|
||||
# type: (int, subprocess.Popen) -> Iterator[Text]
|
||||
"""
|
||||
Return an iterator over process's output lines.
|
||||
over its output's lines.
|
||||
If `timeout` seconds have passed, iterator ends and the process is killed.
|
||||
:param timeout: maximum amount of time to wait for command to end
|
||||
:param process: process to iterate over its output lines
|
||||
"""
|
||||
start = time.time()
|
||||
exit_loop = [] # type: Sequence[IO]
|
||||
|
||||
def loop_helper(file_obj):
|
||||
# type: (IO) -> Sequence[IO]
|
||||
diff = timeout - (time.time() - start)
|
||||
if diff <= 0:
|
||||
return exit_loop
|
||||
return select_read(file_obj, timeout=diff)
|
||||
|
||||
buffer = ""
|
||||
|
||||
for output, in iter(lambda: loop_helper(process.stdout), exit_loop):
|
||||
try:
|
||||
added = output.read(1024).decode("utf8")
|
||||
except EOFError:
|
||||
if buffer:
|
||||
yield buffer
|
||||
return
|
||||
buffer += added
|
||||
lines = buffer.split("\n", 1)
|
||||
|
||||
while len(lines) > 1:
|
||||
line, buffer = lines
|
||||
log.debug("--- %s", line)
|
||||
yield line
|
||||
lines = buffer.split("\n", 1)
|
||||
|
||||
|
||||
def search_lines(lines, search_for, error):
|
||||
# type: (Iterator[Text], Text, Text) -> None
|
||||
"""
|
||||
Fail test if `search_for` string appears nowhere in `lines`.
|
||||
Consumes `lines` up to point where `search_for` is found.
|
||||
:param lines: lines to search in
|
||||
:param search_for: string to search lines for
|
||||
:param error: error to show if not found
|
||||
"""
|
||||
for line in lines:
|
||||
if search_for in line:
|
||||
break
|
||||
else:
|
||||
assert False, error
|
||||
|
||||
|
||||
def search_lines_pattern(lines, pattern, error):
|
||||
# type: (Iterator[Text], Text, Text) -> None
|
||||
"""
|
||||
Like `search_lines` but searches for a pattern.
|
||||
:param lines: lines to search in
|
||||
:param pattern: pattern to search lines for
|
||||
:param error: error to show if not found
|
||||
"""
|
||||
for line in lines:
|
||||
if re.search(pattern, line):
|
||||
break
|
||||
else:
|
||||
assert False, error
|
||||
|
||||
|
||||
def test_entry_point_warning(client):
|
||||
"""
|
||||
non-empty script.entry_point should output a warning
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(diff="print('hello')", entry_point="foo.py", repository=""),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
for line in output:
|
||||
if "found non-empty script.entry_point" in line:
|
||||
break
|
||||
else:
|
||||
assert False, "did not find warning in output"
|
||||
|
||||
|
||||
def test_run_no_dirs(client):
|
||||
"""
|
||||
The arbitrary `code` directory should be selected when there is no `script.repository`
|
||||
"""
|
||||
uuid = uuid4().hex
|
||||
script = "print('{}')".format(uuid)
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(diff=script, entry_point="", repository="", working_dir=""),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
search_lines(
|
||||
output,
|
||||
"found literal script",
|
||||
"task was not recognized as a literal script",
|
||||
)
|
||||
search_lines_pattern(
|
||||
output,
|
||||
r"selected execution directory:.*code",
|
||||
r"did not selected empty `code` dir as execution dir",
|
||||
)
|
||||
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
|
||||
|
||||
|
||||
def test_run_working_dir(client):
|
||||
"""
|
||||
Literal script tasks should respect `working_dir`
|
||||
"""
|
||||
uuid = uuid4().hex
|
||||
script = "print('{}')".format(uuid)
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
diff=script,
|
||||
entry_point="",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
working_dir="space dir",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(120, run_task(task)) as output:
|
||||
search_lines(
|
||||
output,
|
||||
"found literal script",
|
||||
"task was not recognized as a literal script",
|
||||
)
|
||||
search_lines_pattern(
|
||||
output,
|
||||
r"selected execution directory:.*space dir",
|
||||
r"did not selected working_dir as set in execution_info",
|
||||
)
|
||||
search_lines(output, uuid, "did not find uuid {!r} in output".format(uuid))
|
||||
|
||||
|
||||
def test_regular_task(client):
|
||||
"""
|
||||
Test a plain old task
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
entry_point="noop.py",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
message = "Done"
|
||||
search_lines(
|
||||
output, message, "did not reach dummy output message {}".format(message)
|
||||
)
|
||||
|
||||
|
||||
def test_regular_task_nested(client):
|
||||
"""
|
||||
`entry_point` should be relative to `working_dir` if present
|
||||
"""
|
||||
with create_task(
|
||||
client,
|
||||
script=Script(
|
||||
entry_point="noop_nested.py",
|
||||
working_dir="no_reqs",
|
||||
repository="git@bitbucket.org:seematics/roee_test_git.git",
|
||||
),
|
||||
**DEFAULT_TASK_ARGS
|
||||
) as task, iterate_output(SHORT_TIMEOUT, run_task(task)) as output:
|
||||
message = "Done"
|
||||
search_lines(
|
||||
output, message, "did not reach dummy output message {}".format(message)
|
||||
)
|
||||
11
tests/pytest.ini
Normal file
11
tests/pytest.ini
Normal file
@@ -0,0 +1,11 @@
|
||||
[pytest]
|
||||
python_files=*.py
|
||||
script_launch_mode = subprocess
|
||||
env =
|
||||
PYTHONIOENCODING=utf-8
|
||||
log_cli = 1
|
||||
log_cli_level = DEBUG
|
||||
log_cli_format = [%(name)s] [%(asctime)s] [%(levelname)8s] %(message)s
|
||||
log_print = False
|
||||
log_cli_date_format=%Y-%m-%d %H:%M:%S
|
||||
addopts = -p no:warnings
|
||||
2
tests/requirements.txt
Normal file
2
tests/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
pytest
|
||||
-r ../requirements.txt
|
||||
0
tests/requirements/__init__.py
Normal file
0
tests/requirements/__init__.py
Normal file
13
tests/requirements/conftest.py
Normal file
13
tests/requirements/conftest.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--cpu', action='store_true')
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
if not config.option.cpu:
|
||||
return
|
||||
os.environ['PATH'] = ':'.join(p for p in os.environ['PATH'].split(':') if 'cuda' not in p)
|
||||
os.environ['CUDA_VERSION'] = ''
|
||||
os.environ['CUDNN_VERSION'] = ''
|
||||
350
tests/requirements/requirements_substitution.py
Normal file
350
tests/requirements/requirements_substitution.py
Normal file
@@ -0,0 +1,350 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from itertools import chain
|
||||
from os import path
|
||||
from os.path import sep
|
||||
from sys import platform as sys_platform
|
||||
|
||||
import pytest
|
||||
import requirements
|
||||
|
||||
from trains_agent.commands.worker import Worker
|
||||
from trains_agent.helper.package.pytorch import PytorchRequirement
|
||||
from trains_agent.helper.package.requirements import RequirementsManager, \
|
||||
RequirementSubstitution, MarkerRequirement
|
||||
from trains_agent.helper.process import get_bash_output
|
||||
from trains_agent.session import Session
|
||||
|
||||
_cuda_based_packages_hack = ('seematics.caffe', 'lightnet')
|
||||
|
||||
|
||||
def old_get_suffix(session):
|
||||
cuda_version = session.config['agent.cuda_version']
|
||||
cudnn_version = session.config['agent.cudnn_version']
|
||||
if cuda_version and cudnn_version:
|
||||
nvcc_ver = cuda_version.strip()
|
||||
cudnn_ver = cudnn_version.strip()
|
||||
else:
|
||||
if sys_platform == 'win32':
|
||||
nvcc_ver = subprocess.check_output('nvcc --version'.split()).decode('utf-8'). \
|
||||
replace('\r', '').split('\n')
|
||||
else:
|
||||
nvcc_ver = subprocess.check_output(
|
||||
'nvcc --version'.split()).decode('utf-8').split('\n')
|
||||
nvcc_ver = [l for l in nvcc_ver if 'release ' in l]
|
||||
nvcc_ver = nvcc_ver[0][nvcc_ver[0].find(
|
||||
'release ') + len('release '):][:3]
|
||||
nvcc_ver = str(int(float(nvcc_ver) * 10))
|
||||
if sys_platform == 'win32':
|
||||
cuda_lib = subprocess.check_output('where nvcc'.split()).decode('utf-8'). \
|
||||
replace('\r', '').split('\n')[0]
|
||||
cudnn_h = sep.join(cuda_lib.split(
|
||||
sep)[:-2] + ['include', 'cudnn.h'])
|
||||
else:
|
||||
cuda_lib = subprocess.check_output(
|
||||
'which nvcc'.split()).decode('utf-8').split('\n')[0]
|
||||
cudnn_h = path.join(
|
||||
sep, *(cuda_lib.split(sep)[:-2] + ['include', 'cudnn.h']))
|
||||
cudnn_mj, cudnn_mi = None, None
|
||||
for l in open(cudnn_h, 'r'):
|
||||
if 'CUDNN_MAJOR' in l:
|
||||
cudnn_mj = l.split()[-1]
|
||||
if 'CUDNN_MINOR' in l:
|
||||
cudnn_mi = l.split()[-1]
|
||||
if cudnn_mj and cudnn_mi:
|
||||
break
|
||||
cudnn_ver = cudnn_mj + ('0' if not cudnn_mi else cudnn_mi)
|
||||
# build cuda + cudnn version suffix
|
||||
# make sure these are integers, someone else will catch the exception
|
||||
pkg_suffix_ver = '.post' + \
|
||||
str(int(nvcc_ver)) + '.dev' + str(int(cudnn_ver))
|
||||
return pkg_suffix_ver, nvcc_ver, cudnn_ver
|
||||
|
||||
|
||||
def old_replace(session, line):
|
||||
try:
|
||||
cuda_ver_suffix, cuda_ver, cuda_cudnn_ver = old_get_suffix(session)
|
||||
except Exception:
|
||||
return line
|
||||
if line.lstrip().startswith('#'):
|
||||
return line
|
||||
for package_name in _cuda_based_packages_hack:
|
||||
if package_name not in line:
|
||||
continue
|
||||
try:
|
||||
line_lstrip = line.lstrip()
|
||||
if line_lstrip.startswith('http://') or line_lstrip.startswith('https://'):
|
||||
pos = line.find(package_name) + len(package_name)
|
||||
# patch line with specific version
|
||||
line = line[:pos] + \
|
||||
line[pos:].replace('-cp', cuda_ver_suffix + '-cp', 1)
|
||||
else:
|
||||
# this is a pypi package
|
||||
tokens = line.replace('=', ' ').replace('<', ' ').replace('>', ' ').replace(';', ' '). \
|
||||
replace('!', ' ').split()
|
||||
if package_name != tokens[0]:
|
||||
# how did we get here, probably a mistake
|
||||
found_cuda_based_package = False
|
||||
continue
|
||||
|
||||
version_number = None
|
||||
if len(tokens) > 1:
|
||||
# get the package version info
|
||||
test_version_number = tokens[1]
|
||||
# check if we have a valid version, i.e. does not contain post/dev
|
||||
version_number = '.'.join([v for v in test_version_number.split('.')
|
||||
if v and '0' <= v[0] <= '9'])
|
||||
if version_number != test_version_number:
|
||||
raise ValueError()
|
||||
|
||||
# we have no version, but we have to have one
|
||||
if not version_number:
|
||||
# get the latest version from the extra index list
|
||||
pip_search_cmd = ['pip', 'search']
|
||||
if Worker._pip_extra_index_url:
|
||||
pip_search_cmd.extend(
|
||||
chain.from_iterable(('-i', x) for x in Worker._pip_extra_index_url))
|
||||
pip_search_cmd += [package_name]
|
||||
pip_search_output = get_bash_output(
|
||||
' '.join(pip_search_cmd), strip=True)
|
||||
version_number = pip_search_output.split(package_name)[1]
|
||||
version_number = version_number.replace(
|
||||
'(', ' ').replace(')', ' ').split()[0]
|
||||
version_number = '.'.join([v for v in version_number.split('.')
|
||||
if v and '0' <= v[0] <= '9'])
|
||||
if not version_number:
|
||||
# somewhere along the way we failed
|
||||
raise ValueError()
|
||||
|
||||
package_name_version = package_name + '==' + version_number + cuda_ver_suffix
|
||||
if version_number in line:
|
||||
# make sure we have the specific version not >=
|
||||
tokens = line.split(';')
|
||||
line = ';'.join([package_name_version] + tokens[1:])
|
||||
else:
|
||||
# add version to the package_name
|
||||
line = line.replace(package_name, package_name_version, 1)
|
||||
|
||||
# print('pip install %s using CUDA v%s CuDNN v%s' % (package_name, cuda_ver, cuda_cudnn_ver))
|
||||
except ValueError:
|
||||
pass
|
||||
# print('Warning! could not find installed CUDA/CuDNN version for %s, '
|
||||
# 'using original requirements line: %s' % (package_name, line))
|
||||
# add the current line into the cuda requirements list
|
||||
return line
|
||||
|
||||
|
||||
win_condition = 'sys_platform != "win_32"'
|
||||
versions = ('1', '1.4', '1.4.9', '1.5.3.dev0',
|
||||
'1.5.3.dev3', '43.1.2.dev0.post1')
|
||||
|
||||
|
||||
def normalize(result):
|
||||
return result and re.sub(' ?; ?', ';', result)
|
||||
|
||||
|
||||
def parse_one(requirement):
|
||||
return MarkerRequirement(next(requirements.parse(requirement)))
|
||||
|
||||
|
||||
def compare(manager, current_session, pytest_config, requirement, expected):
|
||||
try:
|
||||
res1 = normalize(manager._replace_one(parse_one(requirement)))
|
||||
except ValueError:
|
||||
res1 = None
|
||||
res2 = old_replace(current_session, requirement)
|
||||
if res2 == requirement:
|
||||
res2 = None
|
||||
res2 = normalize(res2)
|
||||
expected = normalize(expected)
|
||||
if pytest_config.option.cpu:
|
||||
assert res1 is None
|
||||
else:
|
||||
assert res1 == expected
|
||||
if requirement not in FAILURES:
|
||||
assert res1 == res2
|
||||
return res1
|
||||
|
||||
|
||||
ARG_VERSIONED = pytest.mark.parametrize('arg', (
|
||||
'nothing{op}{version}{extra}',
|
||||
'something_else{op}{version}{extra}',
|
||||
'something-else{op}{version}{extra}',
|
||||
'something.else{op}{version}{extra}',
|
||||
'seematics.caffe{op}{version}{extra}',
|
||||
'lightnet{op}{version}{extra}',
|
||||
))
|
||||
OP = pytest.mark.parametrize('op', ('==', '<=', '>='))
|
||||
VERSION = pytest.mark.parametrize('version', versions)
|
||||
EXTRA = pytest.mark.parametrize('extra', ('', ' ; ' + win_condition))
|
||||
ARG_PLAIN = pytest.mark.parametrize('arg', (
|
||||
'nothing',
|
||||
'something_else',
|
||||
'something-else',
|
||||
'something.else',
|
||||
'seematics.caffe',
|
||||
'lightnet',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.config-1.0.2-py2.py3-none-any.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.api-1.0.2-py2.py3-none-any.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/seematics.sdk-1.1.0-py2.py3-none-any.whl',
|
||||
))
|
||||
|
||||
|
||||
@ARG_VERSIONED
|
||||
@OP
|
||||
@VERSION
|
||||
@EXTRA
|
||||
def test_with_version(manager, current_session, pytestconfig, arg, op, version, extra):
|
||||
requirement = arg.format(**locals())
|
||||
# expected = EXPECTED.get((arg, op, version))
|
||||
expected = get_expected(current_session, (arg, op, version))
|
||||
expected = expected and expected + extra
|
||||
compare(manager, current_session, pytestconfig, requirement, expected)
|
||||
|
||||
|
||||
@ARG_PLAIN
|
||||
@EXTRA
|
||||
def test_plain(arg, current_session, pytestconfig, extra, manager):
|
||||
# expected = EXPECTED.get(arg)
|
||||
expected = get_expected(current_session, arg)
|
||||
expected = expected and expected + extra
|
||||
compare(manager, current_session, pytestconfig, arg + extra, expected)
|
||||
|
||||
|
||||
def get_expected(session, key):
|
||||
result = EXPECTED.get(key)
|
||||
cuda_version, cudnn_version = session.config['agent.cuda_version'], session.config['agent.cudnn_version']
|
||||
suffix = '.post{cuda_version}.dev{cudnn_version}'.format(**locals()) \
|
||||
if (cuda_version and cudnn_version) \
|
||||
else ''
|
||||
return result and result.format(suffix=suffix)
|
||||
|
||||
|
||||
@ARG_VERSIONED
|
||||
@OP
|
||||
@VERSION
|
||||
@EXTRA
|
||||
def test_str_versioned(arg, op, version, extra):
|
||||
requirement = arg.format(**locals())
|
||||
assert normalize(str(parse_one(requirement))) == normalize(requirement)
|
||||
|
||||
|
||||
@ARG_PLAIN
|
||||
@EXTRA
|
||||
def test_str_plain(arg, extra, manager):
|
||||
requirement = arg.format(**locals())
|
||||
assert normalize(str(parse_one(requirement))) == normalize(requirement)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def current_session(pytestconfig):
|
||||
session = Session()
|
||||
if not pytestconfig.option.cpu:
|
||||
return session
|
||||
session.config['agent.cuda_version'] = None
|
||||
session.config['agent.cudnn_version'] = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def manager(current_session):
|
||||
manager = RequirementsManager(current_session.config)
|
||||
for requirement in (PytorchRequirement, ):
|
||||
manager.register(requirement)
|
||||
return manager
|
||||
|
||||
|
||||
SPECS = (
|
||||
# plain
|
||||
'',
|
||||
|
||||
# greater than
|
||||
'>=0', '>0.1', '>=0.1', '>0.1.2', '>=0.1.2', '>0.1.2.post30', '>=0.1.3.post30', '>0.post30.dev40',
|
||||
'>=0.post30.dev40', '>0.post30-dev40', '>=0.post30-dev40', '>0.0', '>=0.0', '>0.3', '>=0.3', '>=0.4',
|
||||
|
||||
# smaller than
|
||||
'<4', '<4.0', '<4.1.0', '<4.1.0.post80.dev60', '<4.1.0-post80.dev60', '<3', '<2.0', '<1.0.3', '<=4', '<=4.0',
|
||||
'<=4.1.0', '<=4.1.0.post80.dev60', '<=4.1.0-post80.dev60', '<=3', '<=2.0', '<=1.0.3',
|
||||
|
||||
# equals
|
||||
'==0.4.0',
|
||||
|
||||
# equals and
|
||||
'==0.4.0,>=0', '==0.4.0,>=0.1', '==0.4.0,>0.1', '==0.4.0,>=0.1.2', '==0.4.0,>0.1.2', '==0.4.0,<4', '==0.4.0,<=4',
|
||||
'==0.4.0,<4.0', '==0.4.0,<=4.0', '==0.4.0,<4.0.2',
|
||||
|
||||
# smaller and greater
|
||||
'>=0,<4', '>=0,<4.0', '>0.1,<1', '>0.1,<1.1.2', '>0.1,<1.1.2', '>=0,<=4', '>=0,<4.0',
|
||||
'>=0.1,<1', '>0.1,<=1.1.2', '>=0.1,<1.1.2',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('package_manager', ('pip', 'conda'))
|
||||
@pytest.mark.parametrize('os', ('linux', 'windows', 'macos'))
|
||||
@pytest.mark.parametrize('cuda', ('80', '90', '91'))
|
||||
@pytest.mark.parametrize('python', ('2.7', '3.5', '3.6'))
|
||||
@pytest.mark.parametrize('spec', SPECS)
|
||||
@pytest.mark.parametrize('condition', ('', ' ; ' + win_condition))
|
||||
def test_pytorch_success(manager, package_manager, os, cuda, python, spec, condition):
|
||||
pytorch_handler = manager.handlers[-1]
|
||||
pytorch_handler.package_manager = package_manager
|
||||
pytorch_handler.os = os
|
||||
pytorch_handler.cuda = cuda_ver = 'cuda{}'.format(cuda)
|
||||
pytorch_handler.python = python_ver = 'python{}'.format(python)
|
||||
req = 'torch{}{}'.format(spec, condition)
|
||||
expected = pytorch_handler.MAP[package_manager][os][cuda_ver][python_ver]
|
||||
if isinstance(expected, Exception):
|
||||
with pytest.raises(type(expected)):
|
||||
manager._replace_one(parse_one(req))
|
||||
else:
|
||||
expected = expected['0.4.0']
|
||||
result = manager._replace_one(parse_one(req))
|
||||
assert result == expected
|
||||
|
||||
|
||||
get_pip_version = RequirementSubstitution.get_pip_version
|
||||
|
||||
EXPECTED = {
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp35-cp35m-win_amd64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp35'
|
||||
'-cp35m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3-cp27-cp27m-win_amd64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/windows64bit/static/seematics.caffe-1.0.3{suffix}-cp27'
|
||||
'-cp27m-win_amd64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp35-cp35m-linux_x86_64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp35-cp35m'
|
||||
'-linux_x86_64.whl',
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3-cp27-cp27mu-linux_x86_64.whl':
|
||||
'https://s3.amazonaws.com/seematics-pip/public/static/seematics.caffe-1.0.3{suffix}-cp27-cp27mu'
|
||||
'-linux_x86_64.whl',
|
||||
'seematics.caffe': 'seematics.caffe=={}{{suffix}}'.format(get_pip_version('seematics.caffe')),
|
||||
'lightnet': 'lightnet=={}{{suffix}}'.format(get_pip_version('lightnet')),
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '<=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '==', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1'): 'seematics.caffe==1{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1.4'): 'seematics.caffe==1.4{suffix}',
|
||||
('seematics.caffe{op}{version}{extra}', '>=', '1.4.9'): 'seematics.caffe==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '<=', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '==', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1'): 'lightnet==1{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1.4'): 'lightnet==1.4{suffix}',
|
||||
('lightnet{op}{version}{extra}', '>=', '1.4.9'): 'lightnet==1.4.9{suffix}',
|
||||
}
|
||||
FAILURES = {
|
||||
'seematics.caffe ; {}'.format(win_condition),
|
||||
'lightnet ; {}'.format(win_condition)
|
||||
}
|
||||
0
tests/scripts/__init__.py
Normal file
0
tests/scripts/__init__.py
Normal file
7
tests/scripts/python2-test.py
Normal file
7
tests/scripts/python2-test.py
Normal file
@@ -0,0 +1,7 @@
|
||||
def main():
|
||||
assert 1 / 2 == 0
|
||||
print('success')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
8
tests/scripts/python3-test.py
Normal file
8
tests/scripts/python3-test.py
Normal file
@@ -0,0 +1,8 @@
|
||||
def main():
|
||||
if not (1 / 2 == 0.5):
|
||||
raise ValueError('failure')
|
||||
print('success')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user