mirror of
https://github.com/clearml/clearml-server
synced 2025-04-05 13:35:02 +00:00
Add input parameters check to multiple APIs
This commit is contained in:
parent
702b6dc9c8
commit
a47e65d974
@ -13,6 +13,14 @@ from apiserver.config_repo import config
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class TaskRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
|
||||
|
||||
class ModelRequest(Base):
|
||||
model: str = StringField(required=True)
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetMetricsAndVariantsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str,
|
||||
@ -51,6 +64,12 @@ class TaskMetric(Base):
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class LegacyMetricEventsRequest(TaskRequest):
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MetricEventsRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
@ -59,7 +78,14 @@ class MetricEventsRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class VectorMetricsIterHistogramRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetVariantSampleRequest(Base):
|
||||
@ -110,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LegacyLogEventsRequest(TaskEventsRequestBase):
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
@ -160,6 +191,11 @@ class MultiTaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
|
||||
|
||||
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
@ -177,6 +213,14 @@ class TaskPlotsRequest(Base):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetScalarMetricDataRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
@ -42,6 +42,21 @@ class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
|
||||
|
||||
class UpdateForTaskRequest(TaskRequest):
|
||||
uri = fields.StringField()
|
||||
iteration = fields.IntField()
|
||||
override_model_id = fields.StringField()
|
||||
|
||||
|
||||
class UpdateModelRequest(ModelRequest):
|
||||
task = fields.StringField()
|
||||
iteration = fields.IntField()
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_external_artifacts = fields.BoolField(default=True)
|
||||
|
@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
|
||||
enabled = BoolField(default=None, nullable=True)
|
||||
|
||||
|
||||
class GetConfigRequest(Base):
|
||||
path = StringField()
|
||||
|
||||
|
||||
class ReportStatsOptionResponse(Base):
|
||||
supported = BoolField(default=True)
|
||||
enabled = BoolField()
|
||||
|
@ -4,6 +4,10 @@ from jsonmodels.models import Base
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class UserRequest(Base):
|
||||
user = StringField(required=True)
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
|
@ -32,6 +32,14 @@ from apiserver.apimodels.events import (
|
||||
TaskMetric,
|
||||
MultiTaskPlotsRequest,
|
||||
MultiTaskMetricsRequest,
|
||||
LegacyLogEventsRequest,
|
||||
TaskRequest,
|
||||
GetMetricsAndVariantsRequest,
|
||||
ModelRequest,
|
||||
LegacyMetricEventsRequest,
|
||||
GetScalarMetricDataRequest,
|
||||
VectorMetricsIterHistogramRequest,
|
||||
LegacyMultiTaskEventsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
@ -97,15 +105,15 @@ def add_batch(call: APICall, company_id, _):
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
def get_task_log_v1_5(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_log")
|
||||
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
order = call.data.get("order") or "desc"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
order = request.order
|
||||
scroll_id = request.scroll_id
|
||||
batch_size = request.batch_size
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
@ -119,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
|
||||
def get_task_log_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_log", min_version="1.7")
|
||||
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
order = call.data.get("order") or "desc"
|
||||
order = request.order
|
||||
from_ = call.data.get("from") or "head"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
scroll_id = request.scroll_id
|
||||
batch_size = request.batch_size
|
||||
|
||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||
|
||||
@ -177,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
def download_task_log(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.download_task_log")
|
||||
def download_task_log(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
@ -257,10 +265,12 @@ def download_task_log(call, company_id, _):
|
||||
call.result.raw_data = generate()
|
||||
|
||||
|
||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
@endpoint("events.get_vector_metrics_and_variants")
|
||||
def get_vector_metrics_and_variants(
|
||||
call, company_id, request: GetMetricsAndVariantsRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
@ -273,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
|
||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
@endpoint("events.get_scalar_metrics_and_variants")
|
||||
def get_scalar_metrics_and_variants(
|
||||
call, company_id, request: GetMetricsAndVariantsRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
@ -292,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
|
||||
@endpoint(
|
||||
"events.vector_metrics_iter_histogram",
|
||||
required_fields=["task", "metric", "variant"],
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
def vector_metrics_iter_histogram(
|
||||
call, company_id, request: VectorMetricsIterHistogramRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
model_events=model_events,
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
metric = request.metric
|
||||
variant = request.variant
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
task_or_model.get_index_company(), task_id, metric, variant
|
||||
)
|
||||
@ -404,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
|
||||
def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
@endpoint("events.get_scalar_metric_data")
|
||||
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
|
||||
task_id = request.task
|
||||
metric = request.metric
|
||||
scroll_id = request.scroll_id
|
||||
no_scroll = request.no_scroll
|
||||
model_events = request.model_events
|
||||
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
@ -435,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
|
||||
def get_task_latest_scalar_values(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_latest_scalar_values")
|
||||
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
@ -558,11 +571,11 @@ def get_task_single_value_metrics(
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
|
||||
def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.get_multi_task_plots")
|
||||
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
|
||||
task_ids = request.tasks
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||
|
||||
@ -644,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", required_fields=["task"])
|
||||
def get_task_plots_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.get_task_plots")
|
||||
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@ -766,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", required_fields=["task"])
|
||||
def get_debug_images_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.debug_images")
|
||||
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@ -803,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
||||
def get_debug_images_v1_8(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
model_events = call.data.get("model_events", False)
|
||||
@endpoint("events.debug_images", min_version="1.8")
|
||||
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
model_events = request.model_events
|
||||
|
||||
tasks_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
@ -975,8 +988,7 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
|
||||
return {"metrics": []}
|
||||
|
||||
metrics = event_bll.metrics.get_multi_task_metrics(
|
||||
companies=companies,
|
||||
event_type=request.event_type
|
||||
companies=companies, event_type=request.event_type
|
||||
)
|
||||
res = [
|
||||
{
|
||||
@ -985,14 +997,12 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
|
||||
}
|
||||
for m, vars_ in metrics.items()
|
||||
]
|
||||
call.result.data = {
|
||||
"metrics": sorted(res, key=itemgetter("metric"))
|
||||
}
|
||||
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.delete_for_task")
|
||||
def delete_for_task(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
get_task_with_write_access(
|
||||
@ -1005,9 +1015,9 @@ def delete_for_task(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.delete_for_model", required_fields=["model"])
|
||||
def delete_for_model(call: APICall, company_id: str, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("events.delete_for_model")
|
||||
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
|
||||
model_id = request.model
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
model_bll.assert_exists(company_id, model_id, return_models=False)
|
||||
|
@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
ModelsGetRequest,
|
||||
ModelRequest,
|
||||
TaskRequest,
|
||||
UpdateForTaskRequest,
|
||||
UpdateModelRequest,
|
||||
)
|
||||
from apiserver.apimodels.tasks import UpdateTagsRequest
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
@ -67,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
|
||||
unescape_metadata(call, model_data)
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("models.get_by_id")
|
||||
def get_by_id(call: APICall, company_id, request: ModelRequest):
|
||||
model_id = request.model
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
@ -87,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
@endpoint("models.get_by_task_id")
|
||||
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
||||
|
||||
task_id = call.data["task"]
|
||||
task_id = request.task
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
@ -157,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
@endpoint("models.get_all")
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
@ -236,15 +240,15 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
@endpoint("models.update_for_task")
|
||||
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
||||
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
override_model_id = call.data.get("override_model_id")
|
||||
task_id = request.task
|
||||
uri = request.uri
|
||||
iteration = request.iteration
|
||||
override_model_id = request.override_model_id
|
||||
if not (uri or override_model_id) or (uri and override_model_id):
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"exactly one field is required", fields=("uri", "override_model_id")
|
||||
@ -411,9 +415,9 @@ def validate_task(company_id: str, identity: Identity, fields: dict):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("models.edit", response_data_model=UpdateResponse)
|
||||
def edit(call: APICall, company_id, request: UpdateModelRequest):
|
||||
model_id = request.model
|
||||
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
@ -428,7 +432,7 @@ def edit(call: APICall, company_id, _):
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
iteration = request.iteration
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
@ -460,13 +464,9 @@ def edit(call: APICall, company_id, _):
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id):
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
@ -502,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.update", required_fields=["model"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call, company_id, _):
|
||||
call.result.data_model = _update_model(call, company_id)
|
||||
@endpoint("models.update", response_data_model=UpdateResponse)
|
||||
def update(call, company_id, request: UpdateModelRequest):
|
||||
call.result.data_model = _update_model(call, company_id, model_id=request.model)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -629,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
|
||||
)
|
||||
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
|
||||
func=partial(
|
||||
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
|
@ -59,13 +59,12 @@ create_fields = {
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_by_id", required_fields=["project"])
|
||||
def get_by_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
project_id = call.data["project"]
|
||||
@endpoint("projects.get_by_id")
|
||||
def get_by_id(call: APICall, company: str, request: ProjectRequest):
|
||||
project_id = request.project
|
||||
|
||||
with translate_errors_context():
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(company)
|
||||
project = Project.objects(query).first()
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
requested_ids = data.get("id")
|
||||
if isinstance(requested_ids, str):
|
||||
requested_ids = [requested_ids]
|
||||
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=request.shallow_search,
|
||||
data,
|
||||
shallow_search=request.shallow_search,
|
||||
)
|
||||
selected_project_ids = None
|
||||
if request.active_users or request.children_type:
|
||||
@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
if request.include_dataset_stats:
|
||||
dataset_stats = project_bll.get_dataset_stats(
|
||||
company=company_id, project_ids=project_ids, users=request.active_users,
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
users=request.active_users,
|
||||
)
|
||||
for project in projects:
|
||||
project["dataset_stats"] = dataset_stats.get(project["id"])
|
||||
@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company: str, _):
|
||||
data = call.data
|
||||
conform_tag_fields(call, data)
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=data.get("shallow_search", False),
|
||||
data,
|
||||
shallow_search=data.get("shallow_search", False),
|
||||
)
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
company=company,
|
||||
query_dict=data,
|
||||
query=_hidden_query(
|
||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||
@ -277,9 +281,11 @@ def get_all(call: APICall):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.create", required_fields=["name"], response_data_model=IdResponse,
|
||||
"projects.create",
|
||||
required_fields=["name"],
|
||||
response_data_model=IdResponse,
|
||||
)
|
||||
def create(call: APICall):
|
||||
def create(call: APICall, company: str, _):
|
||||
identity = call.identity
|
||||
|
||||
with translate_errors_context():
|
||||
@ -288,15 +294,15 @@ def create(call: APICall):
|
||||
|
||||
return IdResponse(
|
||||
id=ProjectBLL.create(
|
||||
user=identity.user, company=identity.company, **fields,
|
||||
user=identity.user,
|
||||
company=company,
|
||||
**fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call: APICall):
|
||||
@endpoint("projects.update", response_data_model=UpdateResponse)
|
||||
def update(call: APICall, company: str, request: ProjectRequest):
|
||||
"""
|
||||
update
|
||||
|
||||
@ -309,9 +315,7 @@ def update(call: APICall):
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
updated = ProjectBLL.update(
|
||||
company=call.identity.company, project_id=call.data["project"], **fields
|
||||
)
|
||||
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
@ -375,7 +379,6 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: GetUniqueMetricsRequest
|
||||
):
|
||||
|
||||
metrics = project_queries.get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
@ -429,7 +432,6 @@ def get_model_metadata_values(
|
||||
request_data_model=GetParamsRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
|
||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
|
@ -3,7 +3,11 @@ from datetime import datetime
|
||||
from pyhocon.config_tree import NoneValue
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
|
||||
from apiserver.apimodels.server import (
|
||||
ReportStatsOptionRequest,
|
||||
ReportStatsOptionResponse,
|
||||
GetConfigRequest,
|
||||
)
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version, get_build_number, get_commit_number
|
||||
@ -22,8 +26,8 @@ def get_stats(call: APICall):
|
||||
|
||||
|
||||
@endpoint("server.config")
|
||||
def get_config(call: APICall):
|
||||
path = call.data.get("path")
|
||||
def get_config(call: APICall, _, request: GetConfigRequest):
|
||||
path = request.path
|
||||
if path:
|
||||
c = dict(config.get(path))
|
||||
else:
|
||||
|
@ -7,7 +7,11 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
|
||||
from apiserver.apimodels.users import (
|
||||
CreateRequest,
|
||||
SetPreferencesRequest,
|
||||
UserRequest,
|
||||
)
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
|
||||
return res.to_proper_dict()
|
||||
|
||||
|
||||
@endpoint("users.get_by_id", required_fields=["user"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
user_id = call.data["user"]
|
||||
@endpoint("users.get_by_id")
|
||||
def get_by_id(call: APICall, company_id, request: UserRequest):
|
||||
user_id = request.user
|
||||
call.result.data = {"user": get_user(call, company_id, user_id)}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", required_fields=[])
|
||||
@endpoint("users.get_all_ex")
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many_with_join(company=company_id, query_dict=call.data)
|
||||
@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
|
||||
@endpoint("users.get_all_ex", min_version="2.8")
|
||||
def get_all_ex2_8(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
data = call.data
|
||||
@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all", required_fields=[])
|
||||
@endpoint("users.get_all")
|
||||
def get_all(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many(
|
||||
@ -138,9 +142,9 @@ def create(call: APICall):
|
||||
UserBLL.create(call.data_model)
|
||||
|
||||
|
||||
@endpoint("users.delete", required_fields=["user"])
|
||||
def delete(call: APICall):
|
||||
UserBLL.delete(call.data["user"])
|
||||
@endpoint("users.delete")
|
||||
def delete(_: APICall, __, request: UserRequest):
|
||||
UserBLL.delete(request.user)
|
||||
|
||||
|
||||
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||
@ -159,9 +163,9 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||
return User.safe_update(company_id, user_id, partial_update_dict)
|
||||
|
||||
|
||||
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
|
||||
def update(call, company_id, _):
|
||||
user_id = call.data["user"]
|
||||
@endpoint("users.update", response_data_model=UpdateResponse)
|
||||
def update(call, company_id, request: UserRequest):
|
||||
user_id = request.user
|
||||
update_count, updated_fields = update_user(user_id, company_id, call.data)
|
||||
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user