Allow setting status_message in tasks.update

Optimizations and refactoring
This commit is contained in:
allegroai 2021-07-25 14:35:36 +03:00
parent ec2e071ab7
commit 158da9b480
3 changed files with 107 additions and 120 deletions

View File

@ -234,7 +234,7 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)

View File

@ -1,9 +1,12 @@
import random
import string
sys_random = random.SystemRandom()
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
def get_random_string(
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
) -> str:
"""
Returns a securely generated random string.
@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
Taken from the django.utils.crypto module.
"""
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length=20):
def get_client_id(length: int = 20) -> str:
"""
Create a random secret key.
Taken from the Django project.
"""
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars)
def get_secret_key(length=50):
def get_secret_key(length: int = 50) -> str:
"""
Create a random secret key.
@ -33,5 +36,5 @@ def get_secret_key(length=50):
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
"""
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
chars = string.ascii_letters + string.digits
return get_random_string(length, chars)

View File

@ -4,7 +4,6 @@ from functools import partial
from typing import Sequence, Union, Tuple
import attr
import dpath
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@ -220,14 +219,13 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
@ -236,14 +234,13 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@ -252,16 +249,15 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
escape_dict_field(fields, path)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_stripped_fields:
try:
path = f"script/{field}"
value = dpath.get(fields, path)
script = fields.get("script")
if script:
for field in task_script_stripped_fields:
value = script.get(field)
if isinstance(value, str):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
script[field] = value.strip()
return fields
@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
}
def prepare_update_fields(call: APICall, task, call_data):
def prepare_update_fields(call: APICall, call_data):
valid_fields = deepcopy(Task.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
update_fields.update(
status=None, status_reason=None, status_message=None, output__error=None
)
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
if not partial_update_dict:
return UpdateResponse(updated=0)
@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _):
updated_projects = set()
for id, data in items.items():
task = tasks[id]
fields, valid_fields = prepare_update_fields(call, task, data)
fields, valid_fields = prepare_update_fields(call, data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@ -744,8 +739,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = {
"params": [{"task": task, **data} for task, data in tasks_params.items()]
@ -754,39 +748,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
call.result.data = {
"configurations": [
@ -801,10 +792,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
call.result.data = {
"configurations": [
@ -815,31 +805,29 @@ def get_configuration_names(
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest
):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
@endpoint(
@ -1170,15 +1158,14 @@ def ping(_, company_id, request: PingRequest):
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
with translate_errors_context():
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
@endpoint(
@ -1187,31 +1174,28 @@ def add_or_update_artifacts(
request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
with translate_errors_context():
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
@endpoint("tasks.move", request_data_model=MoveRequest)