mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Allow using "$", "." and whitespaces in hyper-parameter keys
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from operator import attrgetter
|
||||
@@ -32,8 +31,7 @@ from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -171,6 +169,11 @@ class TaskBLL(object):
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
execution_overrides["parameters"] = {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
@@ -178,25 +181,28 @@ class TaskBLL(object):
|
||||
a for a in artifacts if a.get("mode") != ArtifactModes.output
|
||||
]
|
||||
now = datetime.utcnow()
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or [],
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
execution=execution_dict,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
new_task.save()
|
||||
|
||||
with translate_errors_context():
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or [],
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
execution=execution_dict,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
new_task.save()
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
@@ -215,18 +221,6 @@ class TaskBLL(object):
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
if task.execution:
|
||||
if task.execution.parameters:
|
||||
cls._validate_execution_parameters(task.execution.parameters)
|
||||
|
||||
@staticmethod
|
||||
def _validate_execution_parameters(parameters):
|
||||
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
|
||||
if invalid_keys:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"execution.parameters keys contain whitespace", keys=invalid_keys
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
pipeline = [
|
||||
@@ -658,7 +652,10 @@ class TaskBLL(object):
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [r["_id"] for r in result.get("results", [])]
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r["_id"])
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
Reference in New Issue
Block a user