mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 02:33:02 +00:00
Allow setting status_message in tasks.update
Optimizations and refactoring
This commit is contained in:
parent
ec2e071ab7
commit
158da9b480
@ -234,7 +234,7 @@ class Task(AttributedDocument):
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
|
@ -1,9 +1,12 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
sys_random = random.SystemRandom()
|
||||
|
||||
|
||||
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
|
||||
def get_random_string(
|
||||
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
|
||||
) -> str:
|
||||
"""
|
||||
Returns a securely generated random string.
|
||||
|
||||
@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
Taken from the django.utils.crypto module.
|
||||
"""
|
||||
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
|
||||
|
||||
def get_client_id(length=20):
|
||||
def get_client_id(length: int = 20) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
Taken from the Django project.
|
||||
"""
|
||||
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
||||
|
||||
def get_secret_key(length=50):
|
||||
def get_secret_key(length: int = 50) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
@ -33,5 +36,5 @@ def get_secret_key(length=50):
|
||||
NOTE: asterisk is not supported due to issues with environment variables containing
|
||||
asterisks (in case the secret key is stored in an environment variable)
|
||||
"""
|
||||
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
chars = string.ascii_letters + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
@ -4,7 +4,6 @@ from functools import partial
|
||||
from typing import Sequence, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from mongoengine import EmbeddedDocument, Q
|
||||
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
@ -220,14 +219,13 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
@ -236,14 +234,13 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
@ -252,16 +249,15 @@ def get_all(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
escape_dict_field(fields, path)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
try:
|
||||
path = f"script/{field}"
|
||||
value = dpath.get(fields, path)
|
||||
script = fields.get("script")
|
||||
if script:
|
||||
for field in task_script_stripped_fields:
|
||||
value = script.get(field)
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
dpath.set(fields, path, value)
|
||||
except KeyError:
|
||||
pass
|
||||
script[field] = value.strip()
|
||||
|
||||
return fields
|
||||
|
||||
@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
}
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
def prepare_update_fields(call: APICall, call_data):
|
||||
valid_fields = deepcopy(Task.user_set_allowed())
|
||||
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
|
||||
update_fields["output__error"] = None
|
||||
update_fields.update(
|
||||
status=None, status_reason=None, status_message=None, output__error=None
|
||||
)
|
||||
t_fields = task_fields
|
||||
t_fields.add("output__error")
|
||||
fields = parse_from_call(call_data, update_fields, t_fields)
|
||||
@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
if not partial_update_dict:
|
||||
return UpdateResponse(updated=0)
|
||||
@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _):
|
||||
updated_projects = set()
|
||||
for id, data in items.items():
|
||||
task = tasks[id]
|
||||
fields, valid_fields = prepare_update_fields(call, task, data)
|
||||
fields, valid_fields = prepare_update_fields(call, data)
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
@ -744,8 +739,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
|
||||
call.result.data = {
|
||||
"params": [{"task": task, **data} for task, data in tasks_params.items()]
|
||||
@ -754,39 +748,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
|
||||
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
|
||||
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
|
||||
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@ -801,10 +792,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@ -815,31 +805,29 @@ def get_configuration_names(
|
||||
|
||||
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
|
||||
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
|
||||
def delete_configuration(
|
||||
call: APICall, company_id, request: DeleteConfigurationRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -1170,15 +1158,14 @@ def ping(_, company_id, request: PingRequest):
|
||||
def add_or_update_artifacts(
|
||||
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -1187,31 +1174,28 @@ def add_or_update_artifacts(
|
||||
request_data_model=DeleteArtifactsRequest,
|
||||
)
|
||||
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.move", request_data_model=MoveRequest)
|
||||
|
Loading…
Reference in New Issue
Block a user