From a47e65d9743cbea0a452c9b37e085547cbae96a0 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 13 Feb 2024 16:15:55 +0200 Subject: [PATCH] Add input parameters check to multiple APIs --- apiserver/apimodels/events.py | 46 +++++++- apiserver/apimodels/models.py | 15 +++ apiserver/apimodels/server.py | 4 + apiserver/apimodels/users.py | 4 + apiserver/services/events.py | 152 ++++++++++++++------------ apiserver/services/models.py | 56 +++++----- apiserver/services/projects.py | 46 ++++---- apiserver/services/server/__init__.py | 10 +- apiserver/services/users.py | 30 ++--- 9 files changed, 225 insertions(+), 138 deletions(-) diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index 72bd3d7..6ddc29f 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -13,6 +13,14 @@ from apiserver.config_repo import config from apiserver.utilities.stringenum import StringEnum +class TaskRequest(Base): + task: str = StringField(required=True) + + +class ModelRequest(Base): + model: str = StringField(required=True) + + class HistogramRequestBase(Base): samples: int = IntField(default=2000, validators=[Min(1), Max(6000)]) key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter) @@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase): model_events: bool = BoolField(default=False) +class GetMetricsAndVariantsRequest(Base): + task: str = StringField(required=True) + model_events: bool = BoolField(default=False) + + class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): tasks: Sequence[str] = ListField( items_types=str, @@ -51,6 +64,12 @@ class TaskMetric(Base): variants: Sequence[str] = ListField(items_types=str) +class LegacyMetricEventsRequest(TaskRequest): + iters: int = IntField(default=1, validators=validators.Min(1)) + scroll_id: str = StringField() + model_events: bool = BoolField(default=False) + + class MetricEventsRequest(Base): metrics: Sequence[TaskMetric] = ListField( items_types=TaskMetric, validators=[Length(minimum_value=1)] @@ -59,7 +78,14 @@ class MetricEventsRequest(Base): navigate_earlier: bool = BoolField(default=True) refresh: bool = BoolField(default=False) scroll_id: str = StringField() - model_events: bool = BoolField() + model_events: bool = BoolField(default=False) + + +class VectorMetricsIterHistogramRequest(Base): + task: str = StringField(required=True) + metric: str = StringField(required=True) + variant: str = StringField(required=True) + model_events: bool = BoolField(default=False) class GetVariantSampleRequest(Base): @@ -110,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase): model_events: bool = BoolField(default=False) +class LegacyLogEventsRequest(TaskEventsRequestBase): + order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc) + scroll_id: str = StringField() + + class LogEventsRequest(TaskEventsRequestBase): batch_size: int = IntField(default=5000) navigate_earlier: bool = BoolField(default=True) @@ -160,6 +191,11 @@ class MultiTaskMetricsRequest(MultiTasksRequestBase): event_type: EventType = ActualEnumField(EventType, default=EventType.all) +class LegacyMultiTaskEventsRequest(MultiTasksRequestBase): + iters: int = IntField(default=1, validators=validators.Min(1)) + scroll_id: str = StringField() + + class MultiTaskPlotsRequest(MultiTasksRequestBase): iters: int = IntField(default=1) scroll_id: str = StringField() @@ -177,6 +213,14 @@ class TaskPlotsRequest(Base): model_events: bool = BoolField(default=False) +class GetScalarMetricDataRequest(Base): + task: str = StringField(required=True) + metric: str = StringField(required=True) + scroll_id: str = StringField() + no_scroll: bool = BoolField(default=False) + model_events: bool = BoolField(default=False) + + class ClearScrollRequest(Base): scroll_id: str = StringField() diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py index be48464..b30e293 100644 --- a/apiserver/apimodels/models.py +++ b/apiserver/apimodels/models.py @@ -42,6 +42,21 @@ class ModelRequest(models.Base): model = fields.StringField(required=True) +class TaskRequest(models.Base): + task = fields.StringField(required=True) + + +class UpdateForTaskRequest(TaskRequest): + uri = fields.StringField() + iteration = fields.IntField() + override_model_id = fields.StringField() + + +class UpdateModelRequest(ModelRequest): + task = fields.StringField() + iteration = fields.IntField() + + class DeleteModelRequest(ModelRequest): force = fields.BoolField(default=False) delete_external_artifacts = fields.BoolField(default=True) diff --git a/apiserver/apimodels/server.py b/apiserver/apimodels/server.py index d977ab6..5857a8d 100644 --- a/apiserver/apimodels/server.py +++ b/apiserver/apimodels/server.py @@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base): enabled = BoolField(default=None, nullable=True) +class GetConfigRequest(Base): + path = StringField() + + class ReportStatsOptionResponse(Base): supported = BoolField(default=True) enabled = BoolField() diff --git a/apiserver/apimodels/users.py b/apiserver/apimodels/users.py index e23ea7b..20f392f 100644 --- a/apiserver/apimodels/users.py +++ b/apiserver/apimodels/users.py @@ -4,6 +4,10 @@ from jsonmodels.models import Base from apiserver.apimodels import DictField +class UserRequest(Base): + user = StringField(required=True) + + class CreateRequest(Base): id = StringField(required=True) name = StringField(required=True) diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 4e928ff..0c4f3ec 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -32,6 +32,14 @@ from apiserver.apimodels.events import ( TaskMetric, MultiTaskPlotsRequest, MultiTaskMetricsRequest, + LegacyLogEventsRequest, + TaskRequest, + GetMetricsAndVariantsRequest, + ModelRequest, + LegacyMetricEventsRequest, + GetScalarMetricDataRequest, + VectorMetricsIterHistogramRequest, + LegacyMultiTaskEventsRequest, ) from apiserver.bll.event import EventBLL from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies @@ -97,15 +105,15 @@ def add_batch(call: APICall, company_id, _): call.result.data = dict(added=added, errors=err_count, errors_info=err_info) -@endpoint("events.get_task_log", required_fields=["task"]) -def get_task_log_v1_5(call, company_id, _): - task_id = call.data["task"] +@endpoint("events.get_task_log") +def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest): + task_id = request.task task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") )[0] - order = call.data.get("order") or "desc" - scroll_id = call.data.get("scroll_id") - batch_size = int(call.data.get("batch_size") or 500) + order = request.order + scroll_id = request.scroll_id + batch_size = request.batch_size events, scroll_id, total_events = event_bll.scroll_task_events( task.get_index_company(), task_id, @@ -119,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _): ) -@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"]) -def get_task_log_v1_7(call, company_id, _): - task_id = call.data["task"] +@endpoint("events.get_task_log", min_version="1.7") +def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest): + task_id = request.task task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") )[0] - order = call.data.get("order") or "desc" + order = request.order from_ = call.data.get("from") or "head" - scroll_id = call.data.get("scroll_id") - batch_size = int(call.data.get("batch_size") or 500) + scroll_id = request.scroll_id + batch_size = request.batch_size scroll_order = "asc" if (from_ == "head") else "desc" @@ -177,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest): ) -@endpoint("events.download_task_log", required_fields=["task"]) -def download_task_log(call, company_id, _): - task_id = call.data["task"] +@endpoint("events.download_task_log") +def download_task_log(call, company_id, request: TaskRequest): + task_id = request.task task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") )[0] @@ -257,10 +265,12 @@ def download_task_log(call, company_id, _): call.result.raw_data = generate() -@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"]) -def get_vector_metrics_and_variants(call, company_id, _): - task_id = call.data["task"] - model_events = call.data["model_events"] +@endpoint("events.get_vector_metrics_and_variants") +def get_vector_metrics_and_variants( + call, company_id, request: GetMetricsAndVariantsRequest +): + task_id = request.task + model_events = request.model_events task_or_model = _assert_task_or_model_exists( company_id, task_id, @@ -273,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _): ) -@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"]) -def get_scalar_metrics_and_variants(call, company_id, _): - task_id = call.data["task"] - model_events = call.data["model_events"] +@endpoint("events.get_scalar_metrics_and_variants") +def get_scalar_metrics_and_variants( + call, company_id, request: GetMetricsAndVariantsRequest +): + task_id = request.task + model_events = request.model_events task_or_model = _assert_task_or_model_exists( company_id, task_id, @@ -292,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _): # todo: !!! currently returning 10,000 records. should decide on a better way to control it @endpoint( "events.vector_metrics_iter_histogram", - required_fields=["task", "metric", "variant"], ) -def vector_metrics_iter_histogram(call, company_id, _): - task_id = call.data["task"] - model_events = call.data["model_events"] +def vector_metrics_iter_histogram( + call, company_id, request: VectorMetricsIterHistogramRequest +): + task_id = request.task + model_events = request.model_events task_or_model = _assert_task_or_model_exists( company_id, task_id, model_events=model_events, )[0] - metric = call.data["metric"] - variant = call.data["variant"] + metric = request.metric + variant = request.variant iterations, vectors = event_bll.get_vector_metrics_per_iter( task_or_model.get_index_company(), task_id, metric, variant ) @@ -404,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest): ) -@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"]) -def get_scalar_metric_data(call, company_id, _): - task_id = call.data["task"] - metric = call.data["metric"] - scroll_id = call.data.get("scroll_id") - no_scroll = call.data.get("no_scroll", False) - model_events = call.data.get("model_events", False) +@endpoint("events.get_scalar_metric_data") +def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest): + task_id = request.task + metric = request.metric + scroll_id = request.scroll_id + no_scroll = request.no_scroll + model_events = request.model_events task_or_model = _assert_task_or_model_exists( company_id, @@ -435,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _): ) -@endpoint("events.get_task_latest_scalar_values", required_fields=["task"]) -def get_task_latest_scalar_values(call, company_id, _): - task_id = call.data["task"] +@endpoint("events.get_task_latest_scalar_values") +def get_task_latest_scalar_values(call, company_id, request: TaskRequest): + task_id = request.task task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") )[0] @@ -558,11 +571,11 @@ def get_task_single_value_metrics( ) -@endpoint("events.get_multi_task_plots", required_fields=["tasks"]) -def get_multi_task_plots_v1_7(call, company_id, _): - task_ids = call.data["tasks"] - iters = call.data.get("iters", 1) - scroll_id = call.data.get("scroll_id") +@endpoint("events.get_multi_task_plots") +def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest): + task_ids = request.tasks + iters = request.iters + scroll_id = request.scroll_id companies = _get_task_or_model_index_companies(company_id, task_ids) @@ -644,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest): ) -@endpoint("events.get_task_plots", required_fields=["task"]) -def get_task_plots_v1_7(call, company_id, _): - task_id = call.data["task"] - iters = call.data.get("iters", 1) - scroll_id = call.data.get("scroll_id") +@endpoint("events.get_task_plots") +def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest): + task_id = request.task + iters = request.iters + scroll_id = request.scroll_id task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") @@ -766,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest): ) -@endpoint("events.debug_images", required_fields=["task"]) -def get_debug_images_v1_7(call, company_id, _): - task_id = call.data["task"] - iters = call.data.get("iters") or 1 - scroll_id = call.data.get("scroll_id") +@endpoint("events.debug_images") +def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest): + task_id = request.task + iters = request.iters + scroll_id = request.scroll_id task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") @@ -803,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _): ) -@endpoint("events.debug_images", min_version="1.8", required_fields=["task"]) -def get_debug_images_v1_8(call, company_id, _): - task_id = call.data["task"] - iters = call.data.get("iters") or 1 - scroll_id = call.data.get("scroll_id") - model_events = call.data.get("model_events", False) +@endpoint("events.debug_images", min_version="1.8") +def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest): + task_id = request.task + iters = request.iters + scroll_id = request.scroll_id + model_events = request.model_events tasks_or_model = _assert_task_or_model_exists( company_id, @@ -975,8 +988,7 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR return {"metrics": []} metrics = event_bll.metrics.get_multi_task_metrics( - companies=companies, - event_type=request.event_type + companies=companies, event_type=request.event_type ) res = [ { @@ -985,14 +997,12 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR } for m, vars_ in metrics.items() ] - call.result.data = { - "metrics": sorted(res, key=itemgetter("metric")) - } + call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))} -@endpoint("events.delete_for_task", required_fields=["task"]) -def delete_for_task(call, company_id, _): - task_id = call.data["task"] +@endpoint("events.delete_for_task") +def delete_for_task(call, company_id, request: TaskRequest): + task_id = request.task allow_locked = call.data.get("allow_locked", False) get_task_with_write_access( @@ -1005,9 +1015,9 @@ def delete_for_task(call, company_id, _): ) -@endpoint("events.delete_for_model", required_fields=["model"]) -def delete_for_model(call: APICall, company_id: str, _): - model_id = call.data["model"] +@endpoint("events.delete_for_model") +def delete_for_model(call: APICall, company_id: str, request: ModelRequest): + model_id = request.model allow_locked = call.data.get("allow_locked", False) model_bll.assert_exists(company_id, model_id, return_models=False) diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 742faa2..b476bd6 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -21,6 +21,10 @@ from apiserver.apimodels.models import ( ModelsPublishManyRequest, ModelsDeleteManyRequest, ModelsGetRequest, + ModelRequest, + TaskRequest, + UpdateForTaskRequest, + UpdateModelRequest, ) from apiserver.apimodels.tasks import UpdateTagsRequest from apiserver.bll.model import ModelBLL, Metadata @@ -67,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]): unescape_metadata(call, model_data) -@endpoint("models.get_by_id", required_fields=["model"]) -def get_by_id(call: APICall, company_id, _): - model_id = call.data["model"] +@endpoint("models.get_by_id") +def get_by_id(call: APICall, company_id, request: ModelRequest): + model_id = request.model call_data = Metadata.escape_query_parameters(call.data) models = Model.get_many( company=company_id, @@ -87,12 +91,12 @@ def get_by_id(call: APICall, company_id, _): call.result.data = {"model": models[0]} -@endpoint("models.get_by_task_id", required_fields=["task"]) -def get_by_task_id(call: APICall, company_id, _): +@endpoint("models.get_by_task_id") +def get_by_task_id(call: APICall, company_id, request: TaskRequest): if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version: raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis") - task_id = call.data["task"] + task_id = request.task query = dict(id=task_id, company=company_id) task = Task.get(_only=["models"], **query) @@ -157,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _): call.result.data = {"models": models} -@endpoint("models.get_all", required_fields=[]) +@endpoint("models.get_all") def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) call_data = Metadata.escape_query_parameters(call.data) @@ -236,15 +240,15 @@ def _reset_cached_tags(company: str, projects: Sequence[str]): ) -@endpoint("models.update_for_task", required_fields=["task"]) -def update_for_task(call: APICall, company_id, _): +@endpoint("models.update_for_task") +def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest): if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version: raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model") - task_id = call.data["task"] - uri = call.data.get("uri") - iteration = call.data.get("iteration") - override_model_id = call.data.get("override_model_id") + task_id = request.task + uri = request.uri + iteration = request.iteration + override_model_id = request.override_model_id if not (uri or override_model_id) or (uri and override_model_id): raise errors.bad_request.MissingRequiredFields( "exactly one field is required", fields=("uri", "override_model_id") @@ -411,9 +415,9 @@ def validate_task(company_id: str, identity: Identity, fields: dict): ) -@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse) -def edit(call: APICall, company_id, _): - model_id = call.data["model"] +@endpoint("models.edit", response_data_model=UpdateResponse) +def edit(call: APICall, company_id, request: UpdateModelRequest): + model_id = request.model model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id) @@ -428,7 +432,7 @@ def edit(call: APICall, company_id, _): d.update(value) fields[key] = d - iteration = call.data.get("iteration") + iteration = request.iteration task_id = model.task or fields.get("task") if task_id and iteration is not None: TaskBLL.update_statistics( @@ -460,13 +464,9 @@ def edit(call: APICall, company_id, _): call.result.data_model = UpdateResponse(updated=0) -def _update_model(call: APICall, company_id, model_id=None): - model_id = model_id or call.data["model"] - +def _update_model(call: APICall, company_id, model_id): model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id) - data = prepare_update_fields(call, company_id, call.data) - task_id = data.get("task") iteration = data.get("iteration") if task_id and iteration is not None: @@ -502,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None): return UpdateResponse(updated=updated_count, fields=updated_fields) -@endpoint( - "models.update", required_fields=["model"], response_data_model=UpdateResponse -) -def update(call, company_id, _): - call.result.data_model = _update_model(call, company_id) +@endpoint("models.update", response_data_model=UpdateResponse) +def update(call, company_id, request: UpdateModelRequest): + call.result.data_model = _update_model(call, company_id, model_id=request.model) @endpoint( @@ -629,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest): ) def unarchive_many(call: APICall, company_id, request: BatchRequest): results, failures = run_batch_operation( - func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user), + func=partial( + ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user + ), ids=request.ids, ) call.result.data_model = BatchResponse( diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index ee6c75e..5d45519 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -59,13 +59,12 @@ create_fields = { } -@endpoint("projects.get_by_id", required_fields=["project"]) -def get_by_id(call): - assert isinstance(call, APICall) - project_id = call.data["project"] +@endpoint("projects.get_by_id") +def get_by_id(call: APICall, company: str, request: ProjectRequest): + project_id = request.project with translate_errors_context(): - query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company) + query = Q(id=project_id) & get_company_or_none_constraint(company) project = Project.objects(query).first() if not project: raise errors.bad_request.InvalidProjectId(id=project_id) @@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): requested_ids = data.get("id") if isinstance(requested_ids, str): requested_ids = [requested_ids] + _adjust_search_parameters( - data, shallow_search=request.shallow_search, + data, + shallow_search=request.shallow_search, ) selected_project_ids = None if request.active_users or request.children_type: @@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): if request.include_dataset_stats: dataset_stats = project_bll.get_dataset_stats( - company=company_id, project_ids=project_ids, users=request.active_users, + company=company_id, + project_ids=project_ids, + users=request.active_users, ) for project in projects: project["dataset_stats"] = dataset_stats.get(project["id"]) @@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): @endpoint("projects.get_all") -def get_all(call: APICall): +def get_all(call: APICall, company: str, _): data = call.data conform_tag_fields(call, data) _adjust_search_parameters( - data, shallow_search=data.get("shallow_search", False), + data, + shallow_search=data.get("shallow_search", False), ) ret_params = {} projects = Project.get_many( - company=call.identity.company, + company=company, query_dict=data, query=_hidden_query( search_hidden=data.get("search_hidden"), ids=data.get("id") @@ -277,9 +281,11 @@ def get_all(call: APICall): @endpoint( - "projects.create", required_fields=["name"], response_data_model=IdResponse, + "projects.create", + required_fields=["name"], + response_data_model=IdResponse, ) -def create(call: APICall): +def create(call: APICall, company: str, _): identity = call.identity with translate_errors_context(): @@ -288,15 +294,15 @@ def create(call: APICall): return IdResponse( id=ProjectBLL.create( - user=identity.user, company=identity.company, **fields, + user=identity.user, + company=company, + **fields, ) ) -@endpoint( - "projects.update", required_fields=["project"], response_data_model=UpdateResponse -) -def update(call: APICall): +@endpoint("projects.update", response_data_model=UpdateResponse) +def update(call: APICall, company: str, request: ProjectRequest): """ update @@ -309,9 +315,7 @@ def update(call: APICall): call.data, create_fields, Project.get_fields(), discard_none_values=False ) conform_tag_fields(call, fields, validate=True) - updated = ProjectBLL.update( - company=call.identity.company, project_id=call.data["project"], **fields - ) + updated = ProjectBLL.update(company=company, project_id=request.project, **fields) conform_output_tags(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) @@ -375,7 +379,6 @@ def delete(call: APICall, company_id: str, request: DeleteRequest): def get_unique_metric_variants( call: APICall, company_id: str, request: GetUniqueMetricsRequest ): - metrics = project_queries.get_unique_metric_variants( company_id, [request.project] if request.project else None, @@ -429,7 +432,6 @@ def get_model_metadata_values( request_data_model=GetParamsRequest, ) def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest): - total, remaining, parameters = project_queries.get_aggregated_project_parameters( company_id, project_ids=[request.project] if request.project else None, diff --git a/apiserver/services/server/__init__.py b/apiserver/services/server/__init__.py index 52f8331..4230d87 100644 --- a/apiserver/services/server/__init__.py +++ b/apiserver/services/server/__init__.py @@ -3,7 +3,11 @@ from datetime import datetime from pyhocon.config_tree import NoneValue from apiserver.apierrors import errors -from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse +from apiserver.apimodels.server import ( + ReportStatsOptionRequest, + ReportStatsOptionResponse, + GetConfigRequest, +) from apiserver.bll.statistics.stats_reporter import StatisticsReporter from apiserver.config_repo import config from apiserver.config.info import get_version, get_build_number, get_commit_number @@ -22,8 +26,8 @@ def get_stats(call: APICall): @endpoint("server.config") -def get_config(call: APICall): - path = call.data.get("path") +def get_config(call: APICall, _, request: GetConfigRequest): + path = request.path if path: c = dict(config.get(path)) else: diff --git a/apiserver/services/users.py b/apiserver/services/users.py index e2a9365..699a4c0 100644 --- a/apiserver/services/users.py +++ b/apiserver/services/users.py @@ -7,7 +7,11 @@ from mongoengine import Q from apiserver.apierrors import errors from apiserver.apimodels.base import UpdateResponse -from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest +from apiserver.apimodels.users import ( + CreateRequest, + SetPreferencesRequest, + UserRequest, +) from apiserver.bll.project import ProjectBLL from apiserver.bll.user import UserBLL from apiserver.config_repo import config @@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None): return res.to_proper_dict() -@endpoint("users.get_by_id", required_fields=["user"]) -def get_by_id(call: APICall, company_id, _): - user_id = call.data["user"] +@endpoint("users.get_by_id") +def get_by_id(call: APICall, company_id, request: UserRequest): + user_id = request.user call.result.data = {"user": get_user(call, company_id, user_id)} -@endpoint("users.get_all_ex", required_fields=[]) +@endpoint("users.get_all_ex") def get_all_ex(call: APICall, company_id, _): with translate_errors_context("retrieving users"): res = User.get_many_with_join(company=company_id, query_dict=call.data) @@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _): call.result.data = {"users": res} -@endpoint("users.get_all_ex", min_version="2.8", required_fields=[]) +@endpoint("users.get_all_ex", min_version="2.8") def get_all_ex2_8(call: APICall, company_id, _): with translate_errors_context("retrieving users"): data = call.data @@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _): call.result.data = {"users": res} -@endpoint("users.get_all", required_fields=[]) +@endpoint("users.get_all") def get_all(call: APICall, company_id, _): with translate_errors_context("retrieving users"): res = User.get_many( @@ -138,9 +142,9 @@ def create(call: APICall): UserBLL.create(call.data_model) -@endpoint("users.delete", required_fields=["user"]) -def delete(call: APICall): - UserBLL.delete(call.data["user"]) +@endpoint("users.delete") +def delete(_: APICall, __, request: UserRequest): + UserBLL.delete(request.user) def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]: @@ -159,9 +163,9 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]: return User.safe_update(company_id, user_id, partial_update_dict) -@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse) -def update(call, company_id, _): - user_id = call.data["user"] +@endpoint("users.update", response_data_model=UpdateResponse) +def update(call, company_id, request: UserRequest): + user_id = request.user update_count, updated_fields = update_user(user_id, company_id, call.data) call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)