From 158da9b48053ec3fdc5e6883c8c4c786950bb346 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 25 Jul 2021 14:35:36 +0300 Subject: [PATCH] Allow setting status_message in tasks.update Optimizations and refactoring --- apiserver/database/model/task/task.py | 2 +- apiserver/service_repo/auth/utils.py | 17 ++- apiserver/services/tasks.py | 208 ++++++++++++-------------- 3 files changed, 107 insertions(+), 120 deletions(-) diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 586485d..f05b6db 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -234,7 +234,7 @@ class Task(AttributedDocument): type = StringField(required=True, choices=get_options(TaskType)) status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus)) status_reason = StringField() - status_message = StringField() + status_message = StringField(user_set_allowed=True) status_changed = DateTimeField() comment = StringField(user_set_allowed=True) created = DateTimeField(required=True, user_set_allowed=True) diff --git a/apiserver/service_repo/auth/utils.py b/apiserver/service_repo/auth/utils.py index 6df9478..b514434 100644 --- a/apiserver/service_repo/auth/utils.py +++ b/apiserver/service_repo/auth/utils.py @@ -1,9 +1,12 @@ import random +import string + sys_random = random.SystemRandom() -def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz' - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): +def get_random_string( + length: int = 12, allowed_chars: str = string.ascii_letters + string.digits +) -> str: """ Returns a securely generated random string. @@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz' Taken from the django.utils.crypto module. """ - return ''.join(sys_random.choice(allowed_chars) for _ in range(length)) + return "".join(sys_random.choice(allowed_chars) for _ in range(length)) -def get_client_id(length=20): +def get_client_id(length: int = 20) -> str: """ Create a random secret key. Taken from the Django project. """ - chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' + chars = string.ascii_uppercase + string.digits return get_random_string(length, chars) -def get_secret_key(length=50): +def get_secret_key(length: int = 50) -> str: """ Create a random secret key. @@ -33,5 +36,5 @@ def get_secret_key(length=50): NOTE: asterisk is not supported due to issues with environment variables containing asterisks (in case the secret key is stored in an environment variable) """ - chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' + chars = string.ascii_letters + string.digits return get_random_string(length, chars) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 95b90e9..a1520f5 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -4,7 +4,6 @@ from functools import partial from typing import Sequence, Union, Tuple import attr -import dpath from mongoengine import EmbeddedDocument, Q from mongoengine.queryset.transform import COMPARISON_OPERATORS from pymongo import UpdateOne @@ -220,14 +219,13 @@ def get_all_ex(call: APICall, company_id, _): call_data = escape_execution_parameters(call) - with translate_errors_context(): - with TimingContext("mongo", "task_get_all_ex"): - _process_include_subprojects(call_data) - tasks = Task.get_many_with_join( - company=company_id, query_dict=call_data, allow_public=True, - ) - unprepare_from_saved(call, tasks) - call.result.data = {"tasks": tasks} + with TimingContext("mongo", "task_get_all_ex"): + _process_include_subprojects(call_data) + tasks = Task.get_many_with_join( + company=company_id, query_dict=call_data, allow_public=True, + ) + unprepare_from_saved(call, tasks) + call.result.data = {"tasks": tasks} @endpoint("tasks.get_by_id_ex", required_fields=["id"]) @@ -236,14 +234,13 @@ def get_by_id_ex(call: APICall, company_id, _): call_data = escape_execution_parameters(call) - with translate_errors_context(): - with TimingContext("mongo", "task_get_by_id_ex"): - tasks = Task.get_many_with_join( - company=company_id, query_dict=call_data, allow_public=True, - ) + with TimingContext("mongo", "task_get_by_id_ex"): + tasks = Task.get_many_with_join( + company=company_id, query_dict=call_data, allow_public=True, + ) - unprepare_from_saved(call, tasks) - call.result.data = {"tasks": tasks} + unprepare_from_saved(call, tasks) + call.result.data = {"tasks": tasks} @endpoint("tasks.get_all", required_fields=[]) @@ -252,16 +249,15 @@ def get_all(call: APICall, company_id, _): call_data = escape_execution_parameters(call) - with translate_errors_context(): - with TimingContext("mongo", "task_get_all"): - tasks = Task.get_many( - company=company_id, - parameters=call_data, - query_dict=call_data, - allow_public=True, - ) - unprepare_from_saved(call, tasks) - call.result.data = {"tasks": tasks} + with TimingContext("mongo", "task_get_all"): + tasks = Task.get_many( + company=company_id, + parameters=call_data, + query_dict=call_data, + allow_public=True, + ) + unprepare_from_saved(call, tasks) + call.result.data = {"tasks": tasks} @endpoint("tasks.get_types", request_data_model=GetTypesRequest) @@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None): escape_dict_field(fields, path) # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths - for field in task_script_stripped_fields: - try: - path = f"script/{field}" - value = dpath.get(fields, path) + script = fields.get("script") + if script: + for field in task_script_stripped_fields: + value = script.get(field) if isinstance(value, str): - value = value.strip() - dpath.set(fields, path, value) - except KeyError: - pass + script[field] = value.strip() return fields @@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest): } -def prepare_update_fields(call: APICall, task, call_data): +def prepare_update_fields(call: APICall, call_data): valid_fields = deepcopy(Task.user_set_allowed()) update_fields = {k: v for k, v in create_fields.items() if k in valid_fields} - update_fields["output__error"] = None + update_fields.update( + status=None, status_reason=None, status_message=None, output__error=None + ) t_fields = task_fields t_fields.add("output__error") fields = parse_from_call(call_data, update_fields, t_fields) @@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest): if not task: raise errors.bad_request.InvalidTaskId(id=task_id) - partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data) + partial_update_dict, valid_fields = prepare_update_fields(call, call.data) if not partial_update_dict: return UpdateResponse(updated=0) @@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _): updated_projects = set() for id, data in items.items(): task = tasks[id] - fields, valid_fields = prepare_update_fields(call, task, data) + fields, valid_fields = prepare_update_fields(call, data) partial_update_dict = Task.get_safe_update_dict(fields) if not partial_update_dict: continue @@ -744,8 +739,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): "tasks.get_hyper_params", request_data_model=GetHyperParamsRequest, ) def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest): - with translate_errors_context(): - tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks) + tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks) call.result.data = { "params": [{"task": task, **data} for task, data in tasks_params.items()] @@ -754,39 +748,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest): @endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest) def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest): - with translate_errors_context(): - call.result.data = { - "updated": HyperParams.edit_params( - company_id, - task_id=request.task, - hyperparams=request.hyperparams, - replace_hyperparams=request.replace_hyperparams, - force=request.force, - ) - } + call.result.data = { + "updated": HyperParams.edit_params( + company_id, + task_id=request.task, + hyperparams=request.hyperparams, + replace_hyperparams=request.replace_hyperparams, + force=request.force, + ) + } @endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest) def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest): - with translate_errors_context(): - call.result.data = { - "deleted": HyperParams.delete_params( - company_id, - task_id=request.task, - hyperparams=request.hyperparams, - force=request.force, - ) - } + call.result.data = { + "deleted": HyperParams.delete_params( + company_id, + task_id=request.task, + hyperparams=request.hyperparams, + force=request.force, + ) + } @endpoint( "tasks.get_configurations", request_data_model=GetConfigurationsRequest, ) def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest): - with translate_errors_context(): - tasks_params = HyperParams.get_configurations( - company_id, task_ids=request.tasks, names=request.names - ) + tasks_params = HyperParams.get_configurations( + company_id, task_ids=request.tasks, names=request.names + ) call.result.data = { "configurations": [ @@ -801,10 +792,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ def get_configuration_names( call: APICall, company_id, request: GetConfigurationNamesRequest ): - with translate_errors_context(): - tasks_params = HyperParams.get_configuration_names( - company_id, task_ids=request.tasks, skip_empty=request.skip_empty - ) + tasks_params = HyperParams.get_configuration_names( + company_id, task_ids=request.tasks, skip_empty=request.skip_empty + ) call.result.data = { "configurations": [ @@ -815,31 +805,29 @@ def get_configuration_names( @endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest) def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest): - with translate_errors_context(): - call.result.data = { - "updated": HyperParams.edit_configuration( - company_id, - task_id=request.task, - configuration=request.configuration, - replace_configuration=request.replace_configuration, - force=request.force, - ) - } + call.result.data = { + "updated": HyperParams.edit_configuration( + company_id, + task_id=request.task, + configuration=request.configuration, + replace_configuration=request.replace_configuration, + force=request.force, + ) + } @endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest) def delete_configuration( call: APICall, company_id, request: DeleteConfigurationRequest ): - with translate_errors_context(): - call.result.data = { - "deleted": HyperParams.delete_configuration( - company_id, - task_id=request.task, - configuration=request.configuration, - force=request.force, - ) - } + call.result.data = { + "deleted": HyperParams.delete_configuration( + company_id, + task_id=request.task, + configuration=request.configuration, + force=request.force, + ) + } @endpoint( @@ -1170,15 +1158,14 @@ def ping(_, company_id, request: PingRequest): def add_or_update_artifacts( call: APICall, company_id, request: AddOrUpdateArtifactsRequest ): - with translate_errors_context(): - call.result.data = { - "updated": Artifacts.add_or_update_artifacts( - company_id=company_id, - task_id=request.task, - artifacts=request.artifacts, - force=request.force, - ) - } + call.result.data = { + "updated": Artifacts.add_or_update_artifacts( + company_id=company_id, + task_id=request.task, + artifacts=request.artifacts, + force=request.force, + ) + } @endpoint( @@ -1187,31 +1174,28 @@ def add_or_update_artifacts( request_data_model=DeleteArtifactsRequest, ) def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest): - with translate_errors_context(): - call.result.data = { - "deleted": Artifacts.delete_artifacts( - company_id=company_id, - task_id=request.task, - artifact_ids=request.artifacts, - force=request.force, - ) - } + call.result.data = { + "deleted": Artifacts.delete_artifacts( + company_id=company_id, + task_id=request.task, + artifact_ids=request.artifacts, + force=request.force, + ) + } @endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest) def make_public(call: APICall, company_id, request: MakePublicRequest): - with translate_errors_context(): - call.result.data = Task.set_public( - company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True - ) + call.result.data = Task.set_public( + company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True + ) @endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest) def make_public(call: APICall, company_id, request: MakePublicRequest): - with translate_errors_context(): - call.result.data = Task.set_public( - company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False - ) + call.result.data = Task.set_public( + company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False + ) @endpoint("tasks.move", request_data_model=MoveRequest)