Fix support for spaces in docker arguments (issue #358)

This commit is contained in:
allegroai 2021-05-19 15:20:03 +03:00
parent 4f7407084d
commit 1e795beec8

View File

@ -11,6 +11,7 @@ import subprocess
import sys import sys
import shutil import shutil
import traceback import traceback
import shlex
from collections import defaultdict from collections import defaultdict
from copy import deepcopy, copy from copy import deepcopy, copy
from datetime import datetime from datetime import datetime
@ -221,6 +222,9 @@ def get_task(session, task_id, *args, **kwargs):
def get_task_container(session, task_id): def get_task_container(session, task_id):
"""
Returns dict with Task docker container setup {container: '', arguments: '', setup_shell_script: ''}
"""
if session.check_min_api_version("2.13"): if session.check_min_api_version("2.13"):
result = session.send_request( result = session.send_request(
service='tasks', service='tasks',
@ -233,12 +237,12 @@ def get_task_container(session, task_id):
try: try:
container = result.json()['data']['tasks'][0]['container'] if result.ok else {} container = result.json()['data']['tasks'][0]['container'] if result.ok else {}
if container.get('arguments'): if container.get('arguments'):
container['arguments'] = str(container.get('arguments')).split(' ') container['arguments'] = shlex.split(str(container.get('arguments')).strip())
except (ValueError, TypeError): except (ValueError, TypeError):
container = {} container = {}
else: else:
response = get_task(session, task_id, only_fields=["execution.docker_cmd"]) response = get_task(session, task_id, only_fields=["execution.docker_cmd"])
task_docker_cmd_parts = str(response.execution.docker_cmd or '').strip().split(' ') task_docker_cmd_parts = shlex.split(str(response.execution.docker_cmd or '').strip())
try: try:
container = dict( container = dict(
container=task_docker_cmd_parts[0], container=task_docker_cmd_parts[0],
@ -251,11 +255,14 @@ def get_task_container(session, task_id):
def set_task_container(session, task_id, docker_image=None, docker_arguments=None, docker_bash_script=None): def set_task_container(session, task_id, docker_image=None, docker_arguments=None, docker_bash_script=None):
if docker_arguments and isinstance(docker_arguments, str):
docker_arguments = [docker_arguments]
if session.check_min_api_version("2.13"): if session.check_min_api_version("2.13"):
container = dict( container = dict(
image=docker_image or None, image=docker_image or '',
arguments=' '.join(docker_arguments) if docker_arguments else None, arguments=' '.join(docker_arguments) if docker_arguments else '',
setup_shell_script=docker_bash_script or None, setup_shell_script=docker_bash_script or '',
) )
result = session.send_request( result = session.send_request(
service='tasks', service='tasks',
@ -1913,7 +1920,6 @@ class Worker(ServiceCommandSection):
if current_task.script.binary and current_task.script.binary.startswith('python') and \ if current_task.script.binary and current_task.script.binary.startswith('python') and \
execution.entry_point and execution.entry_point.split()[0].strip() == '-m': execution.entry_point and execution.entry_point.split()[0].strip() == '-m':
# we need to split it # we need to split it
import shlex
extra.extend(shlex.split(execution.entry_point)) extra.extend(shlex.split(execution.entry_point))
else: else:
extra.append(execution.entry_point) extra.append(execution.entry_point)