Allow setting status_message in tasks.update

Optimizations and refactoring
This commit is contained in:
allegroai 2021-07-25 14:35:36 +03:00
parent ec2e071ab7
commit 158da9b480
3 changed files with 107 additions and 120 deletions

View File

@ -234,7 +234,7 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType)) type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus)) status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField() status_reason = StringField()
status_message = StringField() status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField() status_changed = DateTimeField()
comment = StringField(user_set_allowed=True) comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True) created = DateTimeField(required=True, user_set_allowed=True)

View File

@ -1,9 +1,12 @@
import random import random
import string
sys_random = random.SystemRandom() sys_random = random.SystemRandom()
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz' def get_random_string(
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
) -> str:
""" """
Returns a securely generated random string. Returns a securely generated random string.
@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
Taken from the django.utils.crypto module. Taken from the django.utils.crypto module.
""" """
return ''.join(sys_random.choice(allowed_chars) for _ in range(length)) return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length=20): def get_client_id(length: int = 20) -> str:
""" """
Create a random secret key. Create a random secret key.
Taken from the Django project. Taken from the Django project.
""" """
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars) return get_random_string(length, chars)
def get_secret_key(length=50): def get_secret_key(length: int = 50) -> str:
""" """
Create a random secret key. Create a random secret key.
@ -33,5 +36,5 @@ def get_secret_key(length=50):
NOTE: asterisk is not supported due to issues with environment variables containing NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable) asterisks (in case the secret key is stored in an environment variable)
""" """
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' chars = string.ascii_letters + string.digits
return get_random_string(length, chars) return get_random_string(length, chars)

View File

@ -4,7 +4,6 @@ from functools import partial
from typing import Sequence, Union, Tuple from typing import Sequence, Union, Tuple
import attr import attr
import dpath
from mongoengine import EmbeddedDocument, Q from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne from pymongo import UpdateOne
@ -220,7 +219,6 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"): with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data) _process_include_subprojects(call_data)
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
@ -236,7 +234,6 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"): with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True, company=company_id, query_dict=call_data, allow_public=True,
@ -252,7 +249,6 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"): with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many( tasks = Task.get_many(
company=company_id, company=company_id,
@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
escape_dict_field(fields, path) escape_dict_field(fields, path)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
script = fields.get("script")
if script:
for field in task_script_stripped_fields: for field in task_script_stripped_fields:
try: value = script.get(field)
path = f"script/{field}"
value = dpath.get(fields, path)
if isinstance(value, str): if isinstance(value, str):
value = value.strip() script[field] = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
return fields return fields
@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
} }
def prepare_update_fields(call: APICall, task, call_data): def prepare_update_fields(call: APICall, call_data):
valid_fields = deepcopy(Task.user_set_allowed()) valid_fields = deepcopy(Task.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields} update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None update_fields.update(
status=None, status_reason=None, status_message=None, output__error=None
)
t_fields = task_fields t_fields = task_fields
t_fields.add("output__error") t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields) fields = parse_from_call(call_data, update_fields, t_fields)
@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task: if not task:
raise errors.bad_request.InvalidTaskId(id=task_id) raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data) partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
if not partial_update_dict: if not partial_update_dict:
return UpdateResponse(updated=0) return UpdateResponse(updated=0)
@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _):
updated_projects = set() updated_projects = set()
for id, data in items.items(): for id, data in items.items():
task = tasks[id] task = tasks[id]
fields, valid_fields = prepare_update_fields(call, task, data) fields, valid_fields = prepare_update_fields(call, data)
partial_update_dict = Task.get_safe_update_dict(fields) partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict: if not partial_update_dict:
continue continue
@ -744,7 +739,6 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest, "tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
) )
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest): def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks) tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = { call.result.data = {
@ -754,7 +748,6 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest) @endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest): def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = { call.result.data = {
"updated": HyperParams.edit_params( "updated": HyperParams.edit_params(
company_id, company_id,
@ -768,7 +761,6 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest) @endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest): def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_params( "deleted": HyperParams.delete_params(
company_id, company_id,
@ -783,7 +775,6 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
"tasks.get_configurations", request_data_model=GetConfigurationsRequest, "tasks.get_configurations", request_data_model=GetConfigurationsRequest,
) )
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest): def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations( tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names company_id, task_ids=request.tasks, names=request.names
) )
@ -801,7 +792,6 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
def get_configuration_names( def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest call: APICall, company_id, request: GetConfigurationNamesRequest
): ):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names( tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty company_id, task_ids=request.tasks, skip_empty=request.skip_empty
) )
@ -815,7 +805,6 @@ def get_configuration_names(
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest) @endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest): def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = { call.result.data = {
"updated": HyperParams.edit_configuration( "updated": HyperParams.edit_configuration(
company_id, company_id,
@ -831,7 +820,6 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
def delete_configuration( def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest call: APICall, company_id, request: DeleteConfigurationRequest
): ):
with translate_errors_context():
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_configuration( "deleted": HyperParams.delete_configuration(
company_id, company_id,
@ -1170,7 +1158,6 @@ def ping(_, company_id, request: PingRequest):
def add_or_update_artifacts( def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest call: APICall, company_id, request: AddOrUpdateArtifactsRequest
): ):
with translate_errors_context():
call.result.data = { call.result.data = {
"updated": Artifacts.add_or_update_artifacts( "updated": Artifacts.add_or_update_artifacts(
company_id=company_id, company_id=company_id,
@ -1187,7 +1174,6 @@ def add_or_update_artifacts(
request_data_model=DeleteArtifactsRequest, request_data_model=DeleteArtifactsRequest,
) )
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest): def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
with translate_errors_context():
call.result.data = { call.result.data = {
"deleted": Artifacts.delete_artifacts( "deleted": Artifacts.delete_artifacts(
company_id=company_id, company_id=company_id,
@ -1200,7 +1186,6 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest) @endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest): def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public( call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
) )
@ -1208,7 +1193,6 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest) @endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest): def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public( call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
) )