diff --git a/server/database/model/model.py b/server/database/model/model.py index bb42e80..7aa7219 100644 --- a/server/database/model/model.py +++ b/server/database/model/model.py @@ -12,35 +12,32 @@ from database.model.user import User class Model(DbModelMixin, Document): meta = { - 'db_alias': Database.backend, - 'strict': strict, - 'indexes': [ + "db_alias": Database.backend, + "strict": strict, + "indexes": [ + "parent", + "project", + "task", + ("company", "name"), { - 'name': '%s.model.main_text_index' % Database.backend, - 'fields': [ - '$name', - '$id', - '$comment', - '$parent', - '$task', - '$project', - ], - 'default_language': 'english', - 'weights': { - 'name': 10, - 'id': 10, - 'comment': 10, - 'parent': 5, - 'task': 3, - 'project': 3, - } - } + "name": "%s.model.main_text_index" % Database.backend, + "fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"], + "default_language": "english", + "weights": { + "name": 10, + "id": 10, + "comment": 10, + "parent": 5, + "task": 3, + "project": 3, + }, + }, ], } id = StringField(primary_key=True) name = StrippedStringField(user_set_allowed=True, min_length=3) - parent = StringField(reference_field='Model', required=False) + parent = StringField(reference_field="Model", required=False) user = StringField(required=True, reference_field=User) company = StringField(required=True, reference_field=Company) project = StringField(reference_field=Project, user_set_allowed=True) @@ -49,9 +46,11 @@ class Model(DbModelMixin, Document): comment = StringField(user_set_allowed=True) tags = ListField(StringField(required=True), user_set_allowed=True) system_tags = ListField(StringField(required=True), user_set_allowed=True) - uri = StrippedStringField(default='', user_set_allowed=True) + uri = StrippedStringField(default="", user_set_allowed=True) framework = StringField() design = SafeDictField() labels = ModelLabels() ready = BooleanField(required=True) - ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True) + ui_cache = SafeDictField( + default=dict, user_set_allowed=True, exclude_by_default=True + ) diff --git a/server/database/model/project.py b/server/database/model/project.py index 2accb3b..d961440 100644 --- a/server/database/model/project.py +++ b/server/database/model/project.py @@ -17,12 +17,13 @@ class Project(AttributedDocument): "db_alias": Database.backend, "strict": strict, "indexes": [ + ("company", "name"), { "name": "%s.project.main_text_index" % Database.backend, "fields": ["$name", "$id", "$description"], "default_language": "english", "weights": {"name": 10, "id": 10, "description": 10}, - } + }, ], } diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index 0b55aa5..b159233 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -110,6 +110,12 @@ class Task(AttributedDocument): "created", "started", "completed", + "parent", + "project", + ("company", "name"), + ("company", "type", "system_tags", "status"), + ("company", "project", "type", "system_tags", "status"), + ("status", "last_update"), # for maintenance tasks { "name": "%s.task.main_text_index" % Database.backend, "fields": [ diff --git a/server/services/projects.py b/server/services/projects.py index c7dc548..ff98102 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -33,8 +33,7 @@ create_fields = { } get_all_query_options = Project.QueryParameterOptions( - pattern_fields=("name", "description"), - list_fields=("tags", "system_tags", "id"), + pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"), ) @@ -58,7 +57,7 @@ def get_by_id(call): call.result.data = {"project": project_dict} -def make_projects_get_all_pipelines(project_ids, specific_state=None): +def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None): archived = EntityVisibility.archived.value def ensure_valid_fields(): @@ -74,15 +73,18 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None): "else": "$system_tags", } }, - "status": { - "$ifNull": ["$status", "unknown"] - } + "status": {"$ifNull": ["$status", "unknown"]}, } } status_count_pipeline = [ # count tasks per project per status - {"$match": {"project": {"$in": project_ids}}}, + { + "$match": { + "company": {"$in": [None, "", company_id]}, + "project": {"$in": project_ids}, + } + }, ensure_valid_fields(), { "$group": { @@ -153,7 +155,10 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None): { "$match": { "type": {"$in": ["training", "testing", "annotation"]}, - "project": {"$in": project_ids}, + "project": { + "company": {"$in": [None, "", company_id]}, + "$in": project_ids, + }, } }, ensure_valid_fields(), @@ -195,7 +200,7 @@ def get_all_ex(call: APICall): ids = [project["id"] for project in projects] status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines( - ids, specific_state=specific_state + call.identity.company, ids, specific_state=specific_state ) default_counts = dict.fromkeys(get_options(TaskStatus), 0) @@ -205,7 +210,7 @@ def get_all_ex(call: APICall): status_count = defaultdict(lambda: {}) key = itemgetter(EntityVisibility.archived.value) - for result in Task.aggregate(*status_count_pipeline): + for result in Task.aggregate(status_count_pipeline): for k, group in groupby(sorted(result["counts"], key=key), key): section = ( EntityVisibility.archived if k else EntityVisibility.active @@ -219,7 +224,7 @@ def get_all_ex(call: APICall): runtime = { result["_id"]: {k: v for k, v in result.items() if k != "_id"} - for result in Task.aggregate(*runtime_pipeline) + for result in Task.aggregate(runtime_pipeline) } def safe_get(obj, path, default=None): diff --git a/server/services/tasks.py b/server/services/tasks.py index 2a9984d..a2bef35 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -750,8 +750,7 @@ class CleanupResult(object): deleted_models = attr.ib(type=int) -def cleanup_task(task, force=False): - # type: (Task, bool) -> CleanupResult +def cleanup_task(task: Task, force: bool = False): """ Validate task deletion and delete/modify all its output. :param task: task object