diff --git a/server/apimodels/models.py b/server/apimodels/models.py index 975f569..14f6c9e 100644 --- a/server/apimodels/models.py +++ b/server/apimodels/models.py @@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse from apimodels.tasks import PublishResponse as TaskPublishResponse +class GetFrameworksRequest(models.Base): + projects = fields.ListField(items_types=[str]) + + class CreateModelRequest(models.Base): name = fields.StringField(required=True) uri = fields.StringField(required=True) diff --git a/server/database/model/model.py b/server/database/model/model.py index d7b341b..b777efd 100644 --- a/server/database/model/model.py +++ b/server/database/model/model.py @@ -19,6 +19,7 @@ class Model(DbModelMixin, Document): "parent", "project", "task", + ("company", "framework"), ("company", "name"), ("company", "user"), { diff --git a/server/schema/services/models.conf b/server/schema/services/models.conf index 2ed2db8..1e23c2b 100644 --- a/server/schema/services/models.conf +++ b/server/schema/services/models.conf @@ -252,6 +252,31 @@ } } } + get_frameworks { + "2.8" { + description: "Get the list of frameworks used in the company models" + request { + type: object + properties { + projects { + description: "The list of projects which models will be analyzed. If not passed or empty then all the company and public models will be analyzed" + type: array + items: {type: string} + } + } + } + response { + type: object + properties { + frameworks { + description: "Unique list of the frameworks used in the company models" + type: array + items: {type: string} + } + } + } + } + } update_for_task { "2.1" { description: "Create or update a new model for a task" diff --git a/server/services/models.py b/server/services/models.py index 19bc3a1..1a94762 100644 --- a/server/services/models.py +++ b/server/services/models.py @@ -12,7 +12,9 @@ from apimodels.models import ( PublishModelRequest, PublishModelResponse, ModelTaskPublishResponse, + GetFrameworksRequest, ) +from bll.model import ModelBLL from bll.organization import OrgBLL, Tags from bll.task import TaskBLL from config import config @@ -32,6 +34,7 @@ from timing_context import TimingContext log = config.logger(__file__) org_bll = OrgBLL() +model_bll = ModelBLL() @endpoint("models.get_by_id", required_fields=["model"]) @@ -107,6 +110,15 @@ def get_all(call: APICall, company_id, _): call.result.data = {"models": models} +@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest) +def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest): + call.result.data = { + "frameworks": sorted( + model_bll.get_frameworks(company_id, project_ids=request.projects) + ) + } + + create_fields = { "name": None, "tags": list, diff --git a/server/tests/automated/test_models.py b/server/tests/automated/test_models.py index e7aef2c..e6d9747 100644 --- a/server/tests/automated/test_models.py +++ b/server/tests/automated/test_models.py @@ -7,6 +7,9 @@ IN_PROGRESS = "in_progress" class TestModelsService(TestService): + def setUp(self, version="2.8"): + super().setUp(version=version) + def test_publish_output_model_running_task(self): task_id, model_id = self._create_task_and_model() self._assert_model_ready(model_id, False) @@ -164,6 +167,36 @@ class TestModelsService(TestService): 1000 ) + def test_get_frameworks(self): + framework_1 = "Test framework 1" + framework_2 = "Test framework 2" + + # create model on top level + self._create_model(name="framework model test", framework=framework_1) + + # create model under a project as make it inherit its framework from the task + project = self.create_temp("projects", name="Frameworks test", description="") + task = self._create_task(project=project, execution=dict(framework=framework_2)) + self.api.models.update_for_task( + task=task, + name="framework output model test", + uri="file:///b", + iteration=999, + ) + + # get all frameworks + res = self.api.models.get_frameworks() + self.assertTrue({framework_1, framework_2}.issubset(set(res.frameworks))) + + # get frameworks under the project + res = self.api.models.get_frameworks(projects=[project]) + self.assertEqual([framework_2], res.frameworks) + + # empty result + self.api.tasks.delete(task=task, force=True) + res = self.api.models.get_frameworks(projects=[project]) + self.assertEqual([], res.frameworks) + def _assert_task_status(self, task_id, status): task = self.api.tasks.get_by_id(task=task_id).task assert task.status == status @@ -178,24 +211,23 @@ class TestModelsService(TestService): def _assert_update_task_failure(self): return self.api.raises(TASK_CANNOT_BE_UPDATED_CODES) - def _create_model(self): - model_id = self.create_temp( + def _create_model(self, **kwargs): + return self.create_temp( service="models", - name='test', - uri='file:///a', - labels={} + delete_params=dict(can_fail=True, force=True), + name=kwargs.pop("name", 'test'), + uri=kwargs.pop("name", 'file:///a'), + labels=kwargs.pop("labels", {}), + **kwargs, ) - self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True) - - return model_id - - def _create_task(self): + def _create_task(self, **kwargs): task_id = self.create_temp( service="tasks", - type='testing', - name='server-test', - input=dict(view={}), + type=kwargs.pop("type", 'testing'), + name=kwargs.pop("name", 'server-test'), + input=kwargs.pop("input", dict(view={})), + **kwargs, ) return task_id