Allow using "$", "." and whitespaces in hyper-parameter keys

This commit is contained in:
allegroai
2020-01-02 15:28:50 +02:00
parent 7d10bbdf8e
commit dedac3b2fe
7 changed files with 167 additions and 82 deletions

View File

@@ -1,4 +1,3 @@
import re
from collections import OrderedDict
from datetime import datetime, timedelta
from operator import attrgetter
@@ -32,8 +31,7 @@ from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__)
@@ -171,6 +169,11 @@ class TaskBLL(object):
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
artifacts = execution_dict.get("artifacts")
if artifacts:
@@ -178,25 +181,28 @@ class TaskBLL(object):
a for a in artifacts if a.get("mode") != ArtifactModes.output
]
now = datetime.utcnow()
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
execution=execution_dict,
)
cls.validate(new_task)
new_task.save()
with translate_errors_context():
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
execution=execution_dict,
)
cls.validate(new_task)
new_task.save()
return new_task
@classmethod
@@ -215,18 +221,6 @@ class TaskBLL(object):
cls.validate_execution_model(task)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
if invalid_keys:
raise errors.bad_request.ValidationError(
"execution.parameters keys contain whitespace", keys=invalid_keys
)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [
@@ -658,7 +652,10 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [r["_id"] for r in result.get("results", [])]
results = [
ParameterKeyEscaper.unescape(r["_id"])
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results