Add support for allow_public flag in get_all_ex endpoint

Add `last_changed_by` field on task updates
Fix reports support
This commit is contained in:
allegroai 2022-12-21 18:32:56 +02:00
parent c7cd949fd0
commit ae4c33fa0e
23 changed files with 256 additions and 76 deletions

View File

@ -79,3 +79,4 @@ class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
class ModelsGetRequest(models.Base): class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False) include_stats = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -19,5 +19,7 @@ class EntitiesCountRequest(models.Base):
models = DictField() models = DictField()
pipelines = DictField() pipelines = DictField()
datasets = DictField() datasets = DictField()
reports = DictField()
active_users = fields.ListField(str) active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -62,8 +62,9 @@ class ProjectsGetRequest(models.Base):
include_stats_filter = DictField() include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True) stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active) stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False) non_public = fields.BoolField(default=False) # legacy, use allow_public instead
active_users = fields.ListField(str) active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False) check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False) shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -318,3 +318,8 @@ class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField( models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1) [ModelItemKey], validators=Length(minimum_value=1)
) )
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)

View File

@ -58,8 +58,9 @@ class ModelBLL:
cls, cls,
model_id: str, model_id: str,
company_id: str, company_id: str,
user_id: str,
force_publish_task: bool = False, force_publish_task: bool = False,
publish_task_func: Callable[[str, str, bool], dict] = None, publish_task_func: Callable[[str, str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]: ) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id) model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready: if model.ready:
@ -74,7 +75,7 @@ class ModelBLL:
) )
if task and task.status != TaskStatus.published: if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func( task_publish_res = publish_task_func(
model.task, company_id, force_publish_task model.task, company_id, user_id, force_publish_task
) )
published_task = ModelTaskPublishResponse( published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res id=model.task, data=task_publish_res

View File

@ -133,7 +133,7 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",)) self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields) return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, queue_id: str, force: bool) -> None: def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
""" """
Delete the queue Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found :raise errors.bad_request.InvalidQueueId: if the queue is not found
@ -163,6 +163,7 @@ class QueueBLL(object):
new_status=task.enqueue_status or TaskStatus.created, new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted", status_reason="Queue was deleted",
status_message="", status_message="",
user_id=user_id,
).execute(enqueue_status=None) ).execute(enqueue_status=None)
except Exception as ex: except Exception as ex:
log.exception( log.exception(

View File

@ -48,6 +48,7 @@ class Artifacts:
def add_or_update_artifacts( def add_or_update_artifacts(
cls, cls,
company_id: str, company_id: str,
user_id: str,
task_id: str, task_id: str,
artifacts: Sequence[ApiArtifact], artifacts: Sequence[ApiArtifact],
force: bool, force: bool,
@ -63,12 +64,13 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items() for name, value in artifacts.items()
} }
return update_task(task, update_cmds=update_cmds) return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod @classmethod
def delete_artifacts( def delete_artifacts(
cls, cls,
company_id: str, company_id: str,
user_id: str,
task_id: str, task_id: str,
artifact_ids: Sequence[ArtifactId], artifact_ids: Sequence[ArtifactId],
force: bool, force: bool,
@ -83,4 +85,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
} }
return update_task(task, update_cmds=delete_cmds) return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@ -63,6 +63,7 @@ class HyperParams:
def delete_params( def delete_params(
cls, cls,
company_id: str, company_id: str,
user_id: str,
task_id: str, task_id: str,
hyperparams: Sequence[HyperParamKey], hyperparams: Sequence[HyperParamKey],
force: bool, force: bool,
@ -94,13 +95,17 @@ class HyperParams:
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1 delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return update_task( return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only task,
user_id=user_id,
update_cmds=delete_cmds,
set_last_update=not properties_only,
) )
@classmethod @classmethod
def edit_params( def edit_params(
cls, cls,
company_id: str, company_id: str,
user_id: str,
task_id: str, task_id: str,
hyperparams: Sequence[HyperParamItem], hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str, replace_hyperparams: str,
@ -129,7 +134,10 @@ class HyperParams:
] = value ] = value
return update_task( return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only task,
user_id=user_id,
update_cmds=update_cmds,
set_last_update=not properties_only,
) )
@classmethod @classmethod
@ -201,6 +209,7 @@ class HyperParams:
def edit_configuration( def edit_configuration(
cls, cls,
company_id: str, company_id: str,
user_id: str,
task_id: str, task_id: str,
configuration: Sequence[Configuration], configuration: Sequence[Configuration],
replace_configuration: bool, replace_configuration: bool,
@ -219,11 +228,16 @@ class HyperParams:
for name, value in configuration.items(): for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, update_cmds=update_cmds) return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod @classmethod
def delete_configuration( def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool cls,
company_id: str,
user_id: str,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int: ) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force) task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
@ -232,4 +246,4 @@ class HyperParams:
for name in set(configuration) for name in set(configuration)
} }
return update_task(task, update_cmds=delete_cmds) return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@ -33,7 +33,6 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from .artifacts import artifacts_prepare_for_save from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save from .param_utils import params_prepare_for_save
@ -137,6 +136,7 @@ class TaskBLL:
created=now, created=now,
last_update=now, last_update=now,
last_change=now, last_change=now,
last_changed_by=user,
**fields, **fields,
) )
@ -268,6 +268,7 @@ class TaskBLL:
created=now, created=now,
last_update=now, last_update=now,
last_change=now, last_change=now,
last_changed_by=user_id,
name=name or task.name, name=name or task.name,
comment=comment or task.comment, comment=comment or task.comment,
parent=parent or parent_task, parent=parent or parent_task,
@ -462,7 +463,12 @@ class TaskBLL:
@classmethod @classmethod
def dequeue_and_change_status( def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str, cls,
task: Task,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
): ):
try: try:
cls.dequeue(task, company_id) cls.dequeue(task, company_id)
@ -475,6 +481,7 @@ class TaskBLL:
new_status=task.enqueue_status or TaskStatus.created, new_status=task.enqueue_status or TaskStatus.created,
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
user_id=user_id,
).execute(enqueue_status=None) ).execute(enqueue_status=None)
@classmethod @classmethod

View File

@ -30,7 +30,11 @@ queue_bll = QueueBLL()
def archive_task( def archive_task(
task: Union[str, Task], company_id: str, status_message: str, status_reason: str, task: Union[str, Task],
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> int: ) -> int:
""" """
Deque and archive task Deque and archive task
@ -52,7 +56,11 @@ def archive_task(
) )
try: try:
TaskBLL.dequeue_and_change_status( TaskBLL.dequeue_and_change_status(
task, company_id, status_message, status_reason, task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
) )
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
@ -63,11 +71,12 @@ def archive_task(
status_reason=status_reason, status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value, add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(), last_change=datetime.utcnow(),
last_changed_by=user_id,
) )
def unarchive_task( def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str, task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
) -> int: ) -> int:
""" """
Unarchive task. Return 1 if successful Unarchive task. Return 1 if successful
@ -80,11 +89,16 @@ def unarchive_task(
status_reason=status_reason, status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value, pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(), last_change=datetime.utcnow(),
last_changed_by=user_id,
) )
def dequeue_task( def dequeue_task(
task_id: str, company_id: str, status_message: str, status_reason: str, task_id: str,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]: ) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id) query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query) task = Task.get_for_writing(**query)
@ -92,7 +106,11 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query) raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status( res = TaskBLL.dequeue_and_change_status(
task, company_id, status_message=status_message, status_reason=status_reason, task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
) )
return 1, res return 1, res
@ -100,6 +118,7 @@ def dequeue_task(
def enqueue_task( def enqueue_task(
task_id: str, task_id: str,
company_id: str, company_id: str,
user_id: str,
queue_id: str, queue_id: str,
status_message: str, status_message: str,
status_reason: str, status_reason: str,
@ -139,6 +158,7 @@ def enqueue_task(
status_message=status_message, status_message=status_message,
allow_same_state_transition=False, allow_same_state_transition=False,
force=force, force=force,
user_id=user_id,
).execute(enqueue_status=task.status) ).execute(enqueue_status=task.status)
try: try:
@ -151,6 +171,7 @@ def enqueue_task(
new_status=task.status, new_status=task.status,
force=True, force=True,
status_reason="failed enqueueing", status_reason="failed enqueueing",
user_id=user_id,
).execute(enqueue_status=None) ).execute(enqueue_status=None)
raise raise
@ -220,6 +241,7 @@ def delete_task(
TaskBLL.dequeue_and_change_status( TaskBLL.dequeue_and_change_status(
task, task,
company_id=company_id, company_id=company_id,
user_id=user_id,
status_message=status_message, status_message=status_message,
status_reason=status_reason, status_reason=status_reason,
) )
@ -319,6 +341,7 @@ def reset_task(
force=force, force=force,
status_reason="reset", status_reason="reset",
status_message="reset", status_message="reset",
user_id=user_id,
).execute( ).execute(
started=None, started=None,
completed=None, completed=None,
@ -334,8 +357,9 @@ def reset_task(
def publish_task( def publish_task(
task_id: str, task_id: str,
company_id: str, company_id: str,
user_id: str,
force: bool, force: bool,
publish_model_func: Callable[[str, str], Any] = None, publish_model_func: Callable[[str, str, str], Any] = None,
status_message: str = "", status_message: str = "",
status_reason: str = "", status_reason: str = "",
) -> dict: ) -> dict:
@ -363,7 +387,7 @@ def publish_task(
.first() .first()
) )
if model and not model.ready: if model and not model.ready:
publish_model_func(model.id, company_id) publish_model_func(model.id, company_id, user_id)
# set task status to published, and update (or set) it's new output (view and models) # set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest( return ChangeStatusRequest(
@ -372,6 +396,7 @@ def publish_task(
force=force, force=force,
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
user_id=user_id,
).execute(published=datetime.utcnow(), output=output) ).execute(published=datetime.utcnow(), output=output)
except Exception as ex: except Exception as ex:
@ -384,7 +409,12 @@ def publish_task(
def stop_task( def stop_task(
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool, task_id: str,
company_id: str,
user_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict: ) -> dict:
""" """
Stop a running task. Requires task status 'in_progress' and Stop a running task. Requires task status 'in_progress' and
@ -446,4 +476,5 @@ def stop_task(
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
force=force, force=force,
user_id=user_id,
).execute() ).execute()

View File

@ -26,6 +26,7 @@ class ChangeStatusRequest(object):
force = attr.ib(type=bool, default=False) force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True) allow_same_state_transition = attr.ib(type=bool, default=True)
current_status_override = attr.ib(default=None) current_status_override = attr.ib(default=None)
user_id = attr.ib(type=str, default=None)
def execute(self, **kwargs): def execute(self, **kwargs):
current_status = self.current_status_override or self.task.status current_status = self.current_status_override or self.task.status
@ -44,6 +45,7 @@ class ChangeStatusRequest(object):
status_changed=now, status_changed=now,
last_update=now, last_update=now,
last_change=now, last_change=now,
last_changed_by=self.user_id,
) )
if self.new_status == TaskStatus.queued: if self.new_status == TaskStatus.queued:
@ -165,7 +167,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
def get_task_for_update( def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task: ) -> Task:
""" """
Loads only task id and return the task only if it is updatable (status == 'created') Loads only task id and return the task only if it is updatable (status == 'created')
@ -187,9 +189,9 @@ def get_task_for_update(
return task return task
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True): def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
now = datetime.utcnow() now = datetime.utcnow()
last_updates = dict(last_change=now) last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update: if set_last_update:
last_updates.update(last_update=now) last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates) return task.update(**update_cmds, **last_updates)

View File

@ -196,6 +196,7 @@ class Task(AttributedDocument):
"$name", "$name",
"$id", "$id",
"$comment", "$comment",
"$report",
"$models.input.model", "$models.input.model",
"$models.output.model", "$models.output.model",
"$script.repository", "$script.repository",
@ -206,6 +207,7 @@ class Task(AttributedDocument):
"name": 10, "name": 10,
"id": 10, "id": 10,
"comment": 10, "comment": 10,
"report": 10,
"models.output.model": 2, "models.output.model": 2,
"models.input.model": 2, "models.input.model": 2,
"script.repository": 1, "script.repository": 1,
@ -228,7 +230,7 @@ class Task(AttributedDocument):
), ),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"), range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"), datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"), pattern_fields=("name", "comment", "report"),
) )
id = StringField(primary_key=True) id = StringField(primary_key=True)
@ -242,6 +244,7 @@ class Task(AttributedDocument):
status_message = StringField(user_set_allowed=True) status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField() status_changed = DateTimeField()
comment = StringField(user_set_allowed=True) comment = StringField(user_set_allowed=True)
report = StringField()
created = DateTimeField(required=True, user_set_allowed=True) created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField() started = DateTimeField()
completed = DateTimeField() completed = DateTimeField()
@ -272,6 +275,7 @@ class Task(AttributedDocument):
enqueue_status = StringField( enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True choices=get_options(TaskStatus), exclude_by_default=True
) )
last_changed_by = StringField()
def get_index_company(self) -> str: def get_index_company(self) -> str:
""" """

View File

@ -0,0 +1,17 @@
import logging as log
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import OperationFailure
def migrate_backend(db: Database):
"""
Drop task text index so that the new one including reports field is created
"""
tasks: Collection = db["task"]
try:
tasks.drop_index("backend-db.task.main_text_index")
except OperationFailure as ex:
log.warning(f"Could not delete task text index due to: {str(ex)}")
pass

View File

@ -241,6 +241,15 @@ get_all_ex {
default: false default: false
} }
} }
"999.0": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public models to be returned in the results"
type: boolean
default: true
}
}
}
} }
get_all { get_all {
"2.1" { "2.1" {

View File

@ -176,4 +176,24 @@ get_entities_count {
} }
} }
} }
"999.0": ${get_entities_count."2.22"} {
request.properties {
reports {
type: object
additionalProperties: true
description: Search criteria for reports
}
allow_public {
description: "Allow public entities to be counted in the results"
type: boolean
default: true
}
}
response.properties {
reports {
type: integer
description: The number of reports matching the criteria
}
}
}
} }

View File

@ -611,6 +611,15 @@ get_all_ex {
default: false default: false
} }
} }
"999.0": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public projects to be returned in the results"
type: boolean
default: true
}
}
}
} }
update { update {
"2.1" { "2.1" {

View File

@ -181,6 +181,15 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data" description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
} }
} }
"999.0": ${get_all_ex."2.15"} {
request.properties {
allow_public {
description: "Allow public tasks to be returned in the results"
type: boolean
default: true
}
}
}
} }
get_all { get_all {
"2.1" { "2.1" {

View File

@ -116,7 +116,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
models = Model.get_many_with_join( models = Model.get_many_with_join(
company=company_id, company=company_id,
query_dict=call.data, query_dict=call.data,
allow_public=True, allow_public=request.allow_public,
ret_params=ret_params, ret_params=ret_params,
) )
conform_output_tags(call, models) conform_output_tags(call, models)
@ -482,6 +482,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model( updated, published_task = ModelBLL.publish_model(
model_id=request.model, model_id=request.model,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task, force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None, publish_task_func=publish_task if request.publish_task else None,
) )
@ -500,6 +501,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial( func=partial(
ModelBLL.publish_model, ModelBLL.publish_model,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task, force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None, publish_task_func=publish_task if request.publish_task else None,
), ),

View File

@ -10,7 +10,7 @@ from apiserver.bll.project import ProjectBLL
from apiserver.database.model import User, AttributedDocument, EntityVisibility from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import endpoint, APICall from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
@ -59,6 +59,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
"models": Model, "models": Model,
"pipelines": Project, "pipelines": Project,
"datasets": Project, "datasets": Project,
"reports": Task,
} }
ret = {} ret = {}
for field, entity_cls in entity_classes.items(): for field, entity_cls in entity_classes.items():
@ -66,6 +67,10 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
if data is None: if data is None:
continue continue
if field == "reports":
data["type"] = TaskType.report
data["include_subprojects"] = True
if request.active_users: if request.active_users:
if entity_cls is Project: if entity_cls is Project:
requested_ids = data.get("id") requested_ids = data.get("id")
@ -75,7 +80,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
company=company, company=company,
users=request.active_users, users=request.active_users,
project_ids=requested_ids, project_ids=requested_ids,
allow_public=True, allow_public=request.allow_public,
) )
if not ids: if not ids:
ret[field] = 0 ret[field] = 0
@ -85,11 +90,18 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
data["user"] = request.active_users data["user"] = request.active_users
query = Q() query = Q()
if entity_cls in (Project, Task) and not request.search_hidden: if (
entity_cls in (Project, Task)
and field != "reports"
and not request.search_hidden
):
query &= Q(system_tags__ne=EntityVisibility.hidden.value) query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count( ret[field] = entity_cls.get_count(
company=company, query_dict=data, query=query, allow_public=True, company=company,
query_dict=data,
query=query,
allow_public=request.allow_public,
) )
call.result.data = ret call.result.data = ret

View File

@ -100,7 +100,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data data = call.data
conform_tag_fields(call, data) conform_tag_fields(call, data)
allow_public = not request.non_public allow_public = (
data["allow_public"]
if "allow_public" in data
else not data["non_public"]
if "non_public" in data
else request.allow_public
)
requested_ids = data.get("id") requested_ids = data.get("id")
if isinstance(requested_ids, str): if isinstance(requested_ids, str):
requested_ids = [requested_ids] requested_ids = [requested_ids]

View File

@ -142,7 +142,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest) @endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest): def delete(call: APICall, company_id, req_model: DeleteRequest):
queue_bll.delete( queue_bll.delete(
company_id=company_id, queue_id=req_model.queue, force=req_model.force company_id=company_id,
user_id=call.identity.user,
queue_id=req_model.queue,
force=req_model.force,
) )
call.result.data = {"deleted": 1} call.result.data = {"deleted": 1}

View File

@ -51,9 +51,7 @@ update_fields = {
} }
def _assert_report( def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
company_id, task_id, only_fields=None, requires_write_access=True
):
if only_fields and "type" not in only_fields: if only_fields and "type" not in only_fields:
only_fields += ("type",) only_fields += ("type",)
@ -72,9 +70,7 @@ def _assert_report(
@endpoint("reports.update", response_data_model=UpdateResponse) @endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest): def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report( task = _assert_report(
task_id=request.task, task_id=request.task, company_id=company_id, only_fields=("status",),
company_id=company_id,
only_fields=("status",),
) )
if task.status != TaskStatus.created: if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus( raise errors.bad_request.InvalidTaskStatus(
@ -196,7 +192,7 @@ def _get_task_metrics_from_request(
return task_metrics return task_metrics
@endpoint("reports.get_task_data", required_fields=[]) @endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
process_include_subprojects(call_data) process_include_subprojects(call_data)
@ -212,16 +208,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
res = {"tasks": tasks, **ret_params} res = {"tasks": tasks, **ret_params}
if not ( if not (
request.debug_images request.debug_images or request.plots or request.scalar_metrics_iter_histogram
or request.plots
or request.scalar_metrics_iter_histogram
): ):
return res return res
task_ids = [task["id"] for task in tasks] task_ids = [task["id"] for task in tasks]
company, tasks_or_models = _get_task_or_model_index_company( company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
company_id, task_ids
)
if request.debug_images: if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events( result = event_bll.debug_images_iterator.get_task_events(
company_id=company, company_id=company,
@ -264,9 +256,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
) )
task = _assert_report( task = _assert_report(
company_id=company_id, company_id=company_id, task_id=request.task, only_fields=("project",),
task_id=request.task,
only_fields=("project",),
) )
user_id = call.identity.user user_id = call.identity.user
project_name = request.project_name project_name = request.project_name
@ -297,12 +287,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
"reports.publish", response_data_model=UpdateResponse, "reports.publish", response_data_model=UpdateResponse,
) )
def publish(call: APICall, company_id, request: PublishReportRequest): def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report( task = _assert_report(company_id=company_id, task_id=request.task)
company_id=company_id, task_id=request.task
)
updates = ChangeStatusRequest( updates = ChangeStatusRequest(
task=task, task=task,
company=company_id,
new_status=TaskStatus.published, new_status=TaskStatus.published,
force=True, force=True,
status_reason="", status_reason="",
@ -315,9 +302,7 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
@endpoint("reports.archive") @endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest): def archive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report( task = _assert_report(company_id=company_id, task_id=request.task)
company_id=company_id, task_id=request.task
)
archived = task.update( archived = task.update(
status_message=request.message, status_message=request.message,
status_reason="", status_reason="",
@ -331,9 +316,7 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.unarchive") @endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest): def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report( task = _assert_report(company_id=company_id, task_id=request.task)
company_id=company_id, task_id=request.task
)
unarchived = task.update( unarchived = task.update(
status_message=request.message, status_message=request.message,
status_reason="", status_reason="",
@ -359,9 +342,7 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.delete") @endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest): def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report( task = _assert_report(
company_id=company_id, company_id=company_id, task_id=request.task, only_fields=("project",),
task_id=request.task,
only_fields=("project",),
) )
if ( if (
task.status != TaskStatus.created task.status != TaskStatus.created

View File

@ -64,6 +64,7 @@ from apiserver.apimodels.tasks import (
ResetBatchItem, ResetBatchItem,
CompletedRequest, CompletedRequest,
CompletedResponse, CompletedResponse,
GetAllReq,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
@ -136,7 +137,7 @@ project_bll = ProjectBLL()
def set_task_status_from_call( def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **set_fields request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
) -> dict: ) -> dict:
fields_resolver = SetFieldsResolver(set_fields) fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access( task = TaskBLL.get_task_with_access(
@ -171,6 +172,7 @@ def set_task_status_from_call(
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
force=force, force=force,
user_id=user_id,
).execute(**fields_resolver.get_fields(task)) ).execute(**fields_resolver.get_fields(task))
@ -214,8 +216,8 @@ def _hidden_query(data: dict) -> Q:
return Q(system_tags__ne=EntityVisibility.hidden.value) return Q(system_tags__ne=EntityVisibility.hidden.value)
@endpoint("tasks.get_all_ex", required_fields=[]) @endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, _): def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
@ -226,7 +228,7 @@ def get_all_ex(call: APICall, company_id, _):
company=company_id, company=company_id,
query_dict=call_data, query_dict=call_data,
query=_hidden_query(call_data), query=_hidden_query(call_data),
allow_public=True, allow_public=request.allow_public,
ret_params=ret_params, ret_params=ret_params,
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -291,6 +293,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task( **stop_task(
task_id=req_model.task, task_id=req_model.task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name, user_name=call.identity.user_name,
status_reason=req_model.status_reason, status_reason=req_model.status_reason,
force=req_model.force, force=req_model.force,
@ -308,6 +311,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial( func=partial(
stop_task, stop_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name, user_name=call.identity.user_name,
status_reason=request.status_reason, status_reason=request.status_reason,
force=request.force, force=request.force,
@ -329,7 +333,8 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse( call.result.data_model = UpdateResponse(
**set_task_status_from_call( **set_task_status_from_call(
req_model, req_model,
company_id, company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.stopped, new_status=TaskStatus.stopped,
completed=datetime.utcnow(), completed=datetime.utcnow(),
) )
@ -345,7 +350,8 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
res = StartedResponse( res = StartedResponse(
**set_task_status_from_call( **set_task_status_from_call(
req_model, req_model,
company_id, company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.in_progress, new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
) )
@ -359,7 +365,12 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
) )
def failed(call: APICall, company_id, req_model: UpdateRequest): def failed(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse( call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed) **set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.failed,
)
) )
@ -368,7 +379,11 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
) )
def close(call: APICall, company_id, req_model: UpdateRequest): def close(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse( call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed) **set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.closed)
) )
@ -580,7 +595,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
company_id=company_id, company_id=company_id,
id=task_id, id=task_id,
partial_update_dict=partial_update_dict, partial_update_dict=partial_update_dict,
injected_update=dict(last_change=datetime.utcnow()), injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
),
) )
if updated_count: if updated_count:
new_project = updated_fields.get("project", task.project) new_project = updated_fields.get("project", task.project)
@ -613,7 +630,11 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
raise errors.bad_request.MissingTaskFields( raise errors.bad_request.MissingTaskFields(
"Task has no script field", task=task.id "Task has no script field", task=task.id
) )
res = update_task(task, update_cmds=dict(script__requirements=requirements)) res = update_task(
task,
user_id=call.identity.user,
update_cmds=dict(script__requirements=requirements),
)
call.result.data_model = UpdateResponse(updated=res) call.result.data_model = UpdateResponse(updated=res)
if res: if res:
call.result.data_model.fields = {"script.requirements": requirements} call.result.data_model.fields = {"script.requirements": requirements}
@ -648,7 +669,9 @@ def update_batch(call: APICall, company_id, _):
partial_update_dict = Task.get_safe_update_dict(fields) partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict: if not partial_update_dict:
continue continue
partial_update_dict.update(last_change=now) partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
)
update_op = UpdateOne( update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict} {"_id": id, "company": company_id}, {"$set": partial_update_dict}
) )
@ -725,7 +748,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
} }
if fixed_fields: if fixed_fields:
now = datetime.utcnow() now = datetime.utcnow()
last_change = dict(last_change=now) last_change = dict(last_change=now, last_changed_by=call.identity.user)
if not set(fields).issubset(Task.user_set_allowed()): if not set(fields).issubset(Task.user_set_allowed()):
last_change.update(last_update=now) last_change.update(last_update=now)
fields.update(**last_change) fields.update(**last_change)
@ -762,6 +785,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = { call.result.data = {
"updated": HyperParams.edit_params( "updated": HyperParams.edit_params(
company_id, company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
hyperparams=request.hyperparams, hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams, replace_hyperparams=request.replace_hyperparams,
@ -775,6 +799,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_params( "deleted": HyperParams.delete_params(
company_id, company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
hyperparams=request.hyperparams, hyperparams=request.hyperparams,
force=request.force, force=request.force,
@ -819,6 +844,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = { call.result.data = {
"updated": HyperParams.edit_configuration( "updated": HyperParams.edit_configuration(
company_id, company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
configuration=request.configuration, configuration=request.configuration,
replace_configuration=request.replace_configuration, replace_configuration=request.replace_configuration,
@ -834,6 +860,7 @@ def delete_configuration(
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_configuration( "deleted": HyperParams.delete_configuration(
company_id, company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
configuration=request.configuration, configuration=request.configuration,
force=request.force, force=request.force,
@ -850,6 +877,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task( queued, res = enqueue_task(
task_id=request.task, task_id=request.task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue, queue_id=request.queue,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
@ -874,6 +902,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial( func=partial(
enqueue_task, enqueue_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue, queue_id=request.queue,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
@ -908,6 +937,7 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
dequeued, res = dequeue_task( dequeued, res = dequeue_task(
task_id=request.task, task_id=request.task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
) )
@ -924,6 +954,7 @@ def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial( func=partial(
dequeue_task, dequeue_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
), ),
@ -1019,6 +1050,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
for task in tasks: for task in tasks:
archived += archive_task( archived += archive_task(
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
task=task, task=task,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
@ -1037,6 +1069,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial( func=partial(
archive_task, archive_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
), ),
@ -1058,6 +1091,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial( func=partial(
unarchive_task, unarchive_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
), ),
@ -1136,6 +1170,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task( updates = publish_task(
task_id=request.task, task_id=request.task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
force=request.force, force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None, publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason, status_reason=request.status_reason,
@ -1154,6 +1189,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial( func=partial(
publish_task, publish_task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
force=request.force, force=request.force,
publish_model_func=ModelBLL.publish_model publish_model_func=ModelBLL.publish_model
if request.publish_model if request.publish_model
@ -1180,7 +1216,8 @@ def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse( res = CompletedResponse(
**set_task_status_from_call( **set_task_status_from_call(
request, request,
company_id, company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.completed, new_status=TaskStatus.completed,
completed=datetime.utcnow(), completed=datetime.utcnow(),
) )
@ -1190,6 +1227,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task( publish_res = publish_task(
task_id=request.task, task_id=request.task,
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
force=request.force, force=request.force,
publish_model_func=ModelBLL.publish_model, publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason, status_reason=request.status_reason,
@ -1221,6 +1259,7 @@ def add_or_update_artifacts(
call.result.data = { call.result.data = {
"updated": Artifacts.add_or_update_artifacts( "updated": Artifacts.add_or_update_artifacts(
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
artifacts=request.artifacts, artifacts=request.artifacts,
force=True, force=True,
@ -1237,6 +1276,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = { call.result.data = {
"deleted": Artifacts.delete_artifacts( "deleted": Artifacts.delete_artifacts(
company_id=company_id, company_id=company_id,
user_id=call.identity.user,
task_id=request.task, task_id=request.task,
artifact_ids=request.artifacts, artifact_ids=request.artifacts,
force=True, force=True,
@ -1304,7 +1344,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
@endpoint("tasks.delete_models", min_version="2.13") @endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest): def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True) task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
delete_names = { delete_names = {
@ -1317,5 +1357,5 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
if names if names
} }
updated = task.update(last_change=datetime.utcnow(), **commands,) updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,)
return {"updated": updated} return {"updated": updated}