Fix jsonargparse subcommand config parsing

Fix Lightning integration crashes when a config entry contains "." in its name
This commit is contained in:
allegroai 2024-03-06 09:08:57 +02:00
parent f51ed621f7
commit a1f3279719

View File

@ -1,17 +1,35 @@
import json import json
import copy
import logging import logging
try: try:
# public import capabilities of namespace, util, actions will be deprecated
# import from "protected" instead
from jsonargparse._namespace import Namespace
# noinspection PyProtectedMember
from jsonargparse._util import Path
# noinspection PyProtectedMember
from jsonargparse import ArgumentParser from jsonargparse import ArgumentParser
from jsonargparse.namespace import Namespace
from jsonargparse.util import Path, change_to_path_dir
except ImportError: except ImportError:
ArgumentParser = None try:
from jsonargparse.namespace import Namespace
from jsonargparse.util import Path
from jsonargparse import ArgumentParser
except ImportError:
ArgumentParser = None
try: try:
import jsonargparse.typehints as jsonargparse_typehints # public import capabilities of jsonargparse_typehints will be deprecated
# import from "protected" instead
# noinspection PyProtectedMember
import jsonargparse._typehints as jsonargparse_typehints
except ImportError: except ImportError:
jsonargparse_typehints = None try:
import jsonargparse.typehints as jsonargparse_typehints
except ImportError:
jsonargparse_typehints = None
from ..config import running_remotely, get_remote_task_id from ..config import running_remotely, get_remote_task_id
from .frameworks import _patched_call # noqa from .frameworks import _patched_call # noqa
@ -122,17 +140,19 @@ class PatchJsonArgParse(object):
try: try:
PatchJsonArgParse._load_task_params(parser=obj) PatchJsonArgParse._load_task_params(parser=obj)
params = PatchJsonArgParse.__remote_task_params_dict params = PatchJsonArgParse.__remote_task_params_dict
params_namespace = Namespace()
for k, v in params.items():
params_namespace[k] = v
allow_jsonargparse_overrides_value = True allow_jsonargparse_overrides_value = True
if PatchJsonArgParse._allow_jsonargparse_overrides in params: if PatchJsonArgParse._allow_jsonargparse_overrides in params:
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides) allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides)
if PatchJsonArgParse._ignore_ui_overrides in params: if PatchJsonArgParse._ignore_ui_overrides in params:
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides) allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
params_namespace = Namespace()
for k, v in params.items():
params_namespace[k] = v
if not allow_jsonargparse_overrides_value: if not allow_jsonargparse_overrides_value:
params_namespace = PatchJsonArgParse.__restore_args( params_namespace = PatchJsonArgParse.__restore_args(
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name) obj,
params_namespace,
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
) )
if PatchJsonArgParse._allow_jsonargparse_overrides in params_namespace: if PatchJsonArgParse._allow_jsonargparse_overrides in params_namespace:
del params_namespace[PatchJsonArgParse._allow_jsonargparse_overrides] del params_namespace[PatchJsonArgParse._allow_jsonargparse_overrides]
@ -210,27 +230,37 @@ class PatchJsonArgParse(object):
@staticmethod @staticmethod
def __get_paths_from_dict(dict_): def __get_paths_from_dict(dict_):
paths = [path for path in dict_.values() if isinstance(path, Path)] paths = [(path_key, path) for path_key, path in dict_.items() if isinstance(path, Path)]
for subargs in dict_.values(): for subargs_key, subargs in dict_.items():
if isinstance(subargs, list) and all(isinstance(path, Path) for path in subargs): if isinstance(subargs, list) and all(isinstance(path, Path) for path in subargs):
paths.extend(subargs) paths.extend((subargs_key, path) for path in subargs)
return paths return paths
@staticmethod @staticmethod
def __get_args_from_path(parser, path, subcommand=None): def __get_args_from_path(parser, path, subcommand=None):
with change_to_path_dir(path): try:
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False) # make sure no side effects happen in parser
if subcommand: parser = copy.deepcopy(parser)
parsed_cfg = { argument = path[0]
( if subcommand and argument.startswith(subcommand + PatchJsonArgParse._commands_sep):
(subcommand + PatchJsonArgParse._commands_sep) argument = argument[len(subcommand + PatchJsonArgParse._commands_sep):]
if k not in PatchJsonArgParse._special_fields result = parser.parse_args(
else "" [subcommand, parser.prefix_chars[0] * 2 + argument, path[1].rel_path],
) _skip_check=True,
+ k: v defaults=False,
for k, v in parsed_cfg.items() )
} if PatchJsonArgParse._command_name in result:
return parsed_cfg del result[PatchJsonArgParse._command_name]
else:
result = parser.parse_args(
[parser.prefix_chars[0] * 2 + argument, path[1].rel_path], _skip_check=True, defaults=False
)
if argument in result:
del result[argument]
return result
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse config: {}".format(e))
return Namespace()
@staticmethod @staticmethod
def _handle_namespace(value): def _handle_namespace(value):