mirror of
https://github.com/clearml/clearml-server
synced 2025-06-14 19:58:05 +00:00
Add input parameters check to multiple APIs
This commit is contained in:
parent
702b6dc9c8
commit
a47e65d974
@ -13,6 +13,14 @@ from apiserver.config_repo import config
|
|||||||
from apiserver.utilities.stringenum import StringEnum
|
from apiserver.utilities.stringenum import StringEnum
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRequest(Base):
|
||||||
|
task: str = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRequest(Base):
|
||||||
|
model: str = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
class HistogramRequestBase(Base):
|
class HistogramRequestBase(Base):
|
||||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||||
@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
|||||||
model_events: bool = BoolField(default=False)
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class GetMetricsAndVariantsRequest(Base):
|
||||||
|
task: str = StringField(required=True)
|
||||||
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||||
tasks: Sequence[str] = ListField(
|
tasks: Sequence[str] = ListField(
|
||||||
items_types=str,
|
items_types=str,
|
||||||
@ -51,6 +64,12 @@ class TaskMetric(Base):
|
|||||||
variants: Sequence[str] = ListField(items_types=str)
|
variants: Sequence[str] = ListField(items_types=str)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyMetricEventsRequest(TaskRequest):
|
||||||
|
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||||
|
scroll_id: str = StringField()
|
||||||
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class MetricEventsRequest(Base):
|
class MetricEventsRequest(Base):
|
||||||
metrics: Sequence[TaskMetric] = ListField(
|
metrics: Sequence[TaskMetric] = ListField(
|
||||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||||
@ -59,7 +78,14 @@ class MetricEventsRequest(Base):
|
|||||||
navigate_earlier: bool = BoolField(default=True)
|
navigate_earlier: bool = BoolField(default=True)
|
||||||
refresh: bool = BoolField(default=False)
|
refresh: bool = BoolField(default=False)
|
||||||
scroll_id: str = StringField()
|
scroll_id: str = StringField()
|
||||||
model_events: bool = BoolField()
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class VectorMetricsIterHistogramRequest(Base):
|
||||||
|
task: str = StringField(required=True)
|
||||||
|
metric: str = StringField(required=True)
|
||||||
|
variant: str = StringField(required=True)
|
||||||
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class GetVariantSampleRequest(Base):
|
class GetVariantSampleRequest(Base):
|
||||||
@ -110,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
|
|||||||
model_events: bool = BoolField(default=False)
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyLogEventsRequest(TaskEventsRequestBase):
|
||||||
|
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
|
||||||
|
scroll_id: str = StringField()
|
||||||
|
|
||||||
|
|
||||||
class LogEventsRequest(TaskEventsRequestBase):
|
class LogEventsRequest(TaskEventsRequestBase):
|
||||||
batch_size: int = IntField(default=5000)
|
batch_size: int = IntField(default=5000)
|
||||||
navigate_earlier: bool = BoolField(default=True)
|
navigate_earlier: bool = BoolField(default=True)
|
||||||
@ -160,6 +191,11 @@ class MultiTaskMetricsRequest(MultiTasksRequestBase):
|
|||||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
|
||||||
|
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||||
|
scroll_id: str = StringField()
|
||||||
|
|
||||||
|
|
||||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||||
iters: int = IntField(default=1)
|
iters: int = IntField(default=1)
|
||||||
scroll_id: str = StringField()
|
scroll_id: str = StringField()
|
||||||
@ -177,6 +213,14 @@ class TaskPlotsRequest(Base):
|
|||||||
model_events: bool = BoolField(default=False)
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class GetScalarMetricDataRequest(Base):
|
||||||
|
task: str = StringField(required=True)
|
||||||
|
metric: str = StringField(required=True)
|
||||||
|
scroll_id: str = StringField()
|
||||||
|
no_scroll: bool = BoolField(default=False)
|
||||||
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class ClearScrollRequest(Base):
|
class ClearScrollRequest(Base):
|
||||||
scroll_id: str = StringField()
|
scroll_id: str = StringField()
|
||||||
|
|
||||||
|
@ -42,6 +42,21 @@ class ModelRequest(models.Base):
|
|||||||
model = fields.StringField(required=True)
|
model = fields.StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRequest(models.Base):
|
||||||
|
task = fields.StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateForTaskRequest(TaskRequest):
|
||||||
|
uri = fields.StringField()
|
||||||
|
iteration = fields.IntField()
|
||||||
|
override_model_id = fields.StringField()
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateModelRequest(ModelRequest):
|
||||||
|
task = fields.StringField()
|
||||||
|
iteration = fields.IntField()
|
||||||
|
|
||||||
|
|
||||||
class DeleteModelRequest(ModelRequest):
|
class DeleteModelRequest(ModelRequest):
|
||||||
force = fields.BoolField(default=False)
|
force = fields.BoolField(default=False)
|
||||||
delete_external_artifacts = fields.BoolField(default=True)
|
delete_external_artifacts = fields.BoolField(default=True)
|
||||||
|
@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
|
|||||||
enabled = BoolField(default=None, nullable=True)
|
enabled = BoolField(default=None, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GetConfigRequest(Base):
|
||||||
|
path = StringField()
|
||||||
|
|
||||||
|
|
||||||
class ReportStatsOptionResponse(Base):
|
class ReportStatsOptionResponse(Base):
|
||||||
supported = BoolField(default=True)
|
supported = BoolField(default=True)
|
||||||
enabled = BoolField()
|
enabled = BoolField()
|
||||||
|
@ -4,6 +4,10 @@ from jsonmodels.models import Base
|
|||||||
from apiserver.apimodels import DictField
|
from apiserver.apimodels import DictField
|
||||||
|
|
||||||
|
|
||||||
|
class UserRequest(Base):
|
||||||
|
user = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
class CreateRequest(Base):
|
class CreateRequest(Base):
|
||||||
id = StringField(required=True)
|
id = StringField(required=True)
|
||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
|
@ -32,6 +32,14 @@ from apiserver.apimodels.events import (
|
|||||||
TaskMetric,
|
TaskMetric,
|
||||||
MultiTaskPlotsRequest,
|
MultiTaskPlotsRequest,
|
||||||
MultiTaskMetricsRequest,
|
MultiTaskMetricsRequest,
|
||||||
|
LegacyLogEventsRequest,
|
||||||
|
TaskRequest,
|
||||||
|
GetMetricsAndVariantsRequest,
|
||||||
|
ModelRequest,
|
||||||
|
LegacyMetricEventsRequest,
|
||||||
|
GetScalarMetricDataRequest,
|
||||||
|
VectorMetricsIterHistogramRequest,
|
||||||
|
LegacyMultiTaskEventsRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event import EventBLL
|
from apiserver.bll.event import EventBLL
|
||||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||||
@ -97,15 +105,15 @@ def add_batch(call: APICall, company_id, _):
|
|||||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_task_log", required_fields=["task"])
|
@endpoint("events.get_task_log")
|
||||||
def get_task_log_v1_5(call, company_id, _):
|
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
)[0]
|
)[0]
|
||||||
order = call.data.get("order") or "desc"
|
order = request.order
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
batch_size = int(call.data.get("batch_size") or 500)
|
batch_size = request.batch_size
|
||||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||||
task.get_index_company(),
|
task.get_index_company(),
|
||||||
task_id,
|
task_id,
|
||||||
@ -119,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
|
@endpoint("events.get_task_log", min_version="1.7")
|
||||||
def get_task_log_v1_7(call, company_id, _):
|
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
order = call.data.get("order") or "desc"
|
order = request.order
|
||||||
from_ = call.data.get("from") or "head"
|
from_ = call.data.get("from") or "head"
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
batch_size = int(call.data.get("batch_size") or 500)
|
batch_size = request.batch_size
|
||||||
|
|
||||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||||
|
|
||||||
@ -177,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.download_task_log", required_fields=["task"])
|
@endpoint("events.download_task_log")
|
||||||
def download_task_log(call, company_id, _):
|
def download_task_log(call, company_id, request: TaskRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
)[0]
|
)[0]
|
||||||
@ -257,10 +265,12 @@ def download_task_log(call, company_id, _):
|
|||||||
call.result.raw_data = generate()
|
call.result.raw_data = generate()
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
@endpoint("events.get_vector_metrics_and_variants")
|
||||||
def get_vector_metrics_and_variants(call, company_id, _):
|
def get_vector_metrics_and_variants(
|
||||||
task_id = call.data["task"]
|
call, company_id, request: GetMetricsAndVariantsRequest
|
||||||
model_events = call.data["model_events"]
|
):
|
||||||
|
task_id = request.task
|
||||||
|
model_events = request.model_events
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id,
|
company_id,
|
||||||
task_id,
|
task_id,
|
||||||
@ -273,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
|
@endpoint("events.get_scalar_metrics_and_variants")
|
||||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
def get_scalar_metrics_and_variants(
|
||||||
task_id = call.data["task"]
|
call, company_id, request: GetMetricsAndVariantsRequest
|
||||||
model_events = call.data["model_events"]
|
):
|
||||||
|
task_id = request.task
|
||||||
|
model_events = request.model_events
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id,
|
company_id,
|
||||||
task_id,
|
task_id,
|
||||||
@ -292,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
|||||||
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
|
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"events.vector_metrics_iter_histogram",
|
"events.vector_metrics_iter_histogram",
|
||||||
required_fields=["task", "metric", "variant"],
|
|
||||||
)
|
)
|
||||||
def vector_metrics_iter_histogram(call, company_id, _):
|
def vector_metrics_iter_histogram(
|
||||||
task_id = call.data["task"]
|
call, company_id, request: VectorMetricsIterHistogramRequest
|
||||||
model_events = call.data["model_events"]
|
):
|
||||||
|
task_id = request.task
|
||||||
|
model_events = request.model_events
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id,
|
company_id,
|
||||||
task_id,
|
task_id,
|
||||||
model_events=model_events,
|
model_events=model_events,
|
||||||
)[0]
|
)[0]
|
||||||
metric = call.data["metric"]
|
metric = request.metric
|
||||||
variant = call.data["variant"]
|
variant = request.variant
|
||||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||||
task_or_model.get_index_company(), task_id, metric, variant
|
task_or_model.get_index_company(), task_id, metric, variant
|
||||||
)
|
)
|
||||||
@ -404,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
|
@endpoint("events.get_scalar_metric_data")
|
||||||
def get_scalar_metric_data(call, company_id, _):
|
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
metric = call.data["metric"]
|
metric = request.metric
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
no_scroll = call.data.get("no_scroll", False)
|
no_scroll = request.no_scroll
|
||||||
model_events = call.data.get("model_events", False)
|
model_events = request.model_events
|
||||||
|
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id,
|
company_id,
|
||||||
@ -435,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
|
@endpoint("events.get_task_latest_scalar_values")
|
||||||
def get_task_latest_scalar_values(call, company_id, _):
|
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
)[0]
|
)[0]
|
||||||
@ -558,11 +571,11 @@ def get_task_single_value_metrics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
|
@endpoint("events.get_multi_task_plots")
|
||||||
def get_multi_task_plots_v1_7(call, company_id, _):
|
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
|
||||||
task_ids = call.data["tasks"]
|
task_ids = request.tasks
|
||||||
iters = call.data.get("iters", 1)
|
iters = request.iters
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
|
|
||||||
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||||
|
|
||||||
@ -644,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_task_plots", required_fields=["task"])
|
@endpoint("events.get_task_plots")
|
||||||
def get_task_plots_v1_7(call, company_id, _):
|
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
iters = call.data.get("iters", 1)
|
iters = request.iters
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
|
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
@ -766,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.debug_images", required_fields=["task"])
|
@endpoint("events.debug_images")
|
||||||
def get_debug_images_v1_7(call, company_id, _):
|
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
iters = call.data.get("iters") or 1
|
iters = request.iters
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
|
|
||||||
task = task_bll.assert_exists(
|
task = task_bll.assert_exists(
|
||||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||||
@ -803,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
@endpoint("events.debug_images", min_version="1.8")
|
||||||
def get_debug_images_v1_8(call, company_id, _):
|
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
iters = call.data.get("iters") or 1
|
iters = request.iters
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = request.scroll_id
|
||||||
model_events = call.data.get("model_events", False)
|
model_events = request.model_events
|
||||||
|
|
||||||
tasks_or_model = _assert_task_or_model_exists(
|
tasks_or_model = _assert_task_or_model_exists(
|
||||||
company_id,
|
company_id,
|
||||||
@ -975,8 +988,7 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
|
|||||||
return {"metrics": []}
|
return {"metrics": []}
|
||||||
|
|
||||||
metrics = event_bll.metrics.get_multi_task_metrics(
|
metrics = event_bll.metrics.get_multi_task_metrics(
|
||||||
companies=companies,
|
companies=companies, event_type=request.event_type
|
||||||
event_type=request.event_type
|
|
||||||
)
|
)
|
||||||
res = [
|
res = [
|
||||||
{
|
{
|
||||||
@ -985,14 +997,12 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
|
|||||||
}
|
}
|
||||||
for m, vars_ in metrics.items()
|
for m, vars_ in metrics.items()
|
||||||
]
|
]
|
||||||
call.result.data = {
|
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
|
||||||
"metrics": sorted(res, key=itemgetter("metric"))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
@endpoint("events.delete_for_task")
|
||||||
def delete_for_task(call, company_id, _):
|
def delete_for_task(call, company_id, request: TaskRequest):
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
allow_locked = call.data.get("allow_locked", False)
|
allow_locked = call.data.get("allow_locked", False)
|
||||||
|
|
||||||
get_task_with_write_access(
|
get_task_with_write_access(
|
||||||
@ -1005,9 +1015,9 @@ def delete_for_task(call, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.delete_for_model", required_fields=["model"])
|
@endpoint("events.delete_for_model")
|
||||||
def delete_for_model(call: APICall, company_id: str, _):
|
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
|
||||||
model_id = call.data["model"]
|
model_id = request.model
|
||||||
allow_locked = call.data.get("allow_locked", False)
|
allow_locked = call.data.get("allow_locked", False)
|
||||||
|
|
||||||
model_bll.assert_exists(company_id, model_id, return_models=False)
|
model_bll.assert_exists(company_id, model_id, return_models=False)
|
||||||
|
@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
|
|||||||
ModelsPublishManyRequest,
|
ModelsPublishManyRequest,
|
||||||
ModelsDeleteManyRequest,
|
ModelsDeleteManyRequest,
|
||||||
ModelsGetRequest,
|
ModelsGetRequest,
|
||||||
|
ModelRequest,
|
||||||
|
TaskRequest,
|
||||||
|
UpdateForTaskRequest,
|
||||||
|
UpdateModelRequest,
|
||||||
)
|
)
|
||||||
from apiserver.apimodels.tasks import UpdateTagsRequest
|
from apiserver.apimodels.tasks import UpdateTagsRequest
|
||||||
from apiserver.bll.model import ModelBLL, Metadata
|
from apiserver.bll.model import ModelBLL, Metadata
|
||||||
@ -67,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
|
|||||||
unescape_metadata(call, model_data)
|
unescape_metadata(call, model_data)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_by_id", required_fields=["model"])
|
@endpoint("models.get_by_id")
|
||||||
def get_by_id(call: APICall, company_id, _):
|
def get_by_id(call: APICall, company_id, request: ModelRequest):
|
||||||
model_id = call.data["model"]
|
model_id = request.model
|
||||||
call_data = Metadata.escape_query_parameters(call.data)
|
call_data = Metadata.escape_query_parameters(call.data)
|
||||||
models = Model.get_many(
|
models = Model.get_many(
|
||||||
company=company_id,
|
company=company_id,
|
||||||
@ -87,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
|
|||||||
call.result.data = {"model": models[0]}
|
call.result.data = {"model": models[0]}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
@endpoint("models.get_by_task_id")
|
||||||
def get_by_task_id(call: APICall, company_id, _):
|
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
|
||||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||||
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
||||||
|
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
|
|
||||||
query = dict(id=task_id, company=company_id)
|
query = dict(id=task_id, company=company_id)
|
||||||
task = Task.get(_only=["models"], **query)
|
task = Task.get(_only=["models"], **query)
|
||||||
@ -157,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
|||||||
call.result.data = {"models": models}
|
call.result.data = {"models": models}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_all", required_fields=[])
|
@endpoint("models.get_all")
|
||||||
def get_all(call: APICall, company_id, _):
|
def get_all(call: APICall, company_id, _):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
call_data = Metadata.escape_query_parameters(call.data)
|
call_data = Metadata.escape_query_parameters(call.data)
|
||||||
@ -236,15 +240,15 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.update_for_task", required_fields=["task"])
|
@endpoint("models.update_for_task")
|
||||||
def update_for_task(call: APICall, company_id, _):
|
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
|
||||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||||
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
||||||
|
|
||||||
task_id = call.data["task"]
|
task_id = request.task
|
||||||
uri = call.data.get("uri")
|
uri = request.uri
|
||||||
iteration = call.data.get("iteration")
|
iteration = request.iteration
|
||||||
override_model_id = call.data.get("override_model_id")
|
override_model_id = request.override_model_id
|
||||||
if not (uri or override_model_id) or (uri and override_model_id):
|
if not (uri or override_model_id) or (uri and override_model_id):
|
||||||
raise errors.bad_request.MissingRequiredFields(
|
raise errors.bad_request.MissingRequiredFields(
|
||||||
"exactly one field is required", fields=("uri", "override_model_id")
|
"exactly one field is required", fields=("uri", "override_model_id")
|
||||||
@ -411,9 +415,9 @@ def validate_task(company_id: str, identity: Identity, fields: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
@endpoint("models.edit", response_data_model=UpdateResponse)
|
||||||
def edit(call: APICall, company_id, _):
|
def edit(call: APICall, company_id, request: UpdateModelRequest):
|
||||||
model_id = call.data["model"]
|
model_id = request.model
|
||||||
|
|
||||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||||
|
|
||||||
@ -428,7 +432,7 @@ def edit(call: APICall, company_id, _):
|
|||||||
d.update(value)
|
d.update(value)
|
||||||
fields[key] = d
|
fields[key] = d
|
||||||
|
|
||||||
iteration = call.data.get("iteration")
|
iteration = request.iteration
|
||||||
task_id = model.task or fields.get("task")
|
task_id = model.task or fields.get("task")
|
||||||
if task_id and iteration is not None:
|
if task_id and iteration is not None:
|
||||||
TaskBLL.update_statistics(
|
TaskBLL.update_statistics(
|
||||||
@ -460,13 +464,9 @@ def edit(call: APICall, company_id, _):
|
|||||||
call.result.data_model = UpdateResponse(updated=0)
|
call.result.data_model = UpdateResponse(updated=0)
|
||||||
|
|
||||||
|
|
||||||
def _update_model(call: APICall, company_id, model_id=None):
|
def _update_model(call: APICall, company_id, model_id):
|
||||||
model_id = model_id or call.data["model"]
|
|
||||||
|
|
||||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||||
|
|
||||||
data = prepare_update_fields(call, company_id, call.data)
|
data = prepare_update_fields(call, company_id, call.data)
|
||||||
|
|
||||||
task_id = data.get("task")
|
task_id = data.get("task")
|
||||||
iteration = data.get("iteration")
|
iteration = data.get("iteration")
|
||||||
if task_id and iteration is not None:
|
if task_id and iteration is not None:
|
||||||
@ -502,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
|
|||||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint("models.update", response_data_model=UpdateResponse)
|
||||||
"models.update", required_fields=["model"], response_data_model=UpdateResponse
|
def update(call, company_id, request: UpdateModelRequest):
|
||||||
)
|
call.result.data_model = _update_model(call, company_id, model_id=request.model)
|
||||||
def update(call, company_id, _):
|
|
||||||
call.result.data_model = _update_model(call, company_id)
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
@ -629,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
|
|||||||
)
|
)
|
||||||
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||||
results, failures = run_batch_operation(
|
results, failures = run_batch_operation(
|
||||||
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
|
func=partial(
|
||||||
|
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
|
||||||
|
),
|
||||||
ids=request.ids,
|
ids=request.ids,
|
||||||
)
|
)
|
||||||
call.result.data_model = BatchResponse(
|
call.result.data_model = BatchResponse(
|
||||||
|
@ -59,13 +59,12 @@ create_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("projects.get_by_id", required_fields=["project"])
|
@endpoint("projects.get_by_id")
|
||||||
def get_by_id(call):
|
def get_by_id(call: APICall, company: str, request: ProjectRequest):
|
||||||
assert isinstance(call, APICall)
|
project_id = request.project
|
||||||
project_id = call.data["project"]
|
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
|
query = Q(id=project_id) & get_company_or_none_constraint(company)
|
||||||
project = Project.objects(query).first()
|
project = Project.objects(query).first()
|
||||||
if not project:
|
if not project:
|
||||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||||
@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
|||||||
requested_ids = data.get("id")
|
requested_ids = data.get("id")
|
||||||
if isinstance(requested_ids, str):
|
if isinstance(requested_ids, str):
|
||||||
requested_ids = [requested_ids]
|
requested_ids = [requested_ids]
|
||||||
|
|
||||||
_adjust_search_parameters(
|
_adjust_search_parameters(
|
||||||
data, shallow_search=request.shallow_search,
|
data,
|
||||||
|
shallow_search=request.shallow_search,
|
||||||
)
|
)
|
||||||
selected_project_ids = None
|
selected_project_ids = None
|
||||||
if request.active_users or request.children_type:
|
if request.active_users or request.children_type:
|
||||||
@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
|||||||
|
|
||||||
if request.include_dataset_stats:
|
if request.include_dataset_stats:
|
||||||
dataset_stats = project_bll.get_dataset_stats(
|
dataset_stats = project_bll.get_dataset_stats(
|
||||||
company=company_id, project_ids=project_ids, users=request.active_users,
|
company=company_id,
|
||||||
|
project_ids=project_ids,
|
||||||
|
users=request.active_users,
|
||||||
)
|
)
|
||||||
for project in projects:
|
for project in projects:
|
||||||
project["dataset_stats"] = dataset_stats.get(project["id"])
|
project["dataset_stats"] = dataset_stats.get(project["id"])
|
||||||
@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
|||||||
|
|
||||||
|
|
||||||
@endpoint("projects.get_all")
|
@endpoint("projects.get_all")
|
||||||
def get_all(call: APICall):
|
def get_all(call: APICall, company: str, _):
|
||||||
data = call.data
|
data = call.data
|
||||||
conform_tag_fields(call, data)
|
conform_tag_fields(call, data)
|
||||||
_adjust_search_parameters(
|
_adjust_search_parameters(
|
||||||
data, shallow_search=data.get("shallow_search", False),
|
data,
|
||||||
|
shallow_search=data.get("shallow_search", False),
|
||||||
)
|
)
|
||||||
ret_params = {}
|
ret_params = {}
|
||||||
projects = Project.get_many(
|
projects = Project.get_many(
|
||||||
company=call.identity.company,
|
company=company,
|
||||||
query_dict=data,
|
query_dict=data,
|
||||||
query=_hidden_query(
|
query=_hidden_query(
|
||||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||||
@ -277,9 +281,11 @@ def get_all(call: APICall):
|
|||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"projects.create", required_fields=["name"], response_data_model=IdResponse,
|
"projects.create",
|
||||||
|
required_fields=["name"],
|
||||||
|
response_data_model=IdResponse,
|
||||||
)
|
)
|
||||||
def create(call: APICall):
|
def create(call: APICall, company: str, _):
|
||||||
identity = call.identity
|
identity = call.identity
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
@ -288,15 +294,15 @@ def create(call: APICall):
|
|||||||
|
|
||||||
return IdResponse(
|
return IdResponse(
|
||||||
id=ProjectBLL.create(
|
id=ProjectBLL.create(
|
||||||
user=identity.user, company=identity.company, **fields,
|
user=identity.user,
|
||||||
|
company=company,
|
||||||
|
**fields,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint("projects.update", response_data_model=UpdateResponse)
|
||||||
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
|
def update(call: APICall, company: str, request: ProjectRequest):
|
||||||
)
|
|
||||||
def update(call: APICall):
|
|
||||||
"""
|
"""
|
||||||
update
|
update
|
||||||
|
|
||||||
@ -309,9 +315,7 @@ def update(call: APICall):
|
|||||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||||
)
|
)
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
updated = ProjectBLL.update(
|
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
|
||||||
company=call.identity.company, project_id=call.data["project"], **fields
|
|
||||||
)
|
|
||||||
conform_output_tags(call, fields)
|
conform_output_tags(call, fields)
|
||||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||||
|
|
||||||
@ -375,7 +379,6 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
|
|||||||
def get_unique_metric_variants(
|
def get_unique_metric_variants(
|
||||||
call: APICall, company_id: str, request: GetUniqueMetricsRequest
|
call: APICall, company_id: str, request: GetUniqueMetricsRequest
|
||||||
):
|
):
|
||||||
|
|
||||||
metrics = project_queries.get_unique_metric_variants(
|
metrics = project_queries.get_unique_metric_variants(
|
||||||
company_id,
|
company_id,
|
||||||
[request.project] if request.project else None,
|
[request.project] if request.project else None,
|
||||||
@ -429,7 +432,6 @@ def get_model_metadata_values(
|
|||||||
request_data_model=GetParamsRequest,
|
request_data_model=GetParamsRequest,
|
||||||
)
|
)
|
||||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||||
|
|
||||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||||
company_id,
|
company_id,
|
||||||
project_ids=[request.project] if request.project else None,
|
project_ids=[request.project] if request.project else None,
|
||||||
|
@ -3,7 +3,11 @@ from datetime import datetime
|
|||||||
from pyhocon.config_tree import NoneValue
|
from pyhocon.config_tree import NoneValue
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
|
from apiserver.apimodels.server import (
|
||||||
|
ReportStatsOptionRequest,
|
||||||
|
ReportStatsOptionResponse,
|
||||||
|
GetConfigRequest,
|
||||||
|
)
|
||||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.config.info import get_version, get_build_number, get_commit_number
|
from apiserver.config.info import get_version, get_build_number, get_commit_number
|
||||||
@ -22,8 +26,8 @@ def get_stats(call: APICall):
|
|||||||
|
|
||||||
|
|
||||||
@endpoint("server.config")
|
@endpoint("server.config")
|
||||||
def get_config(call: APICall):
|
def get_config(call: APICall, _, request: GetConfigRequest):
|
||||||
path = call.data.get("path")
|
path = request.path
|
||||||
if path:
|
if path:
|
||||||
c = dict(config.get(path))
|
c = dict(config.get(path))
|
||||||
else:
|
else:
|
||||||
|
@ -7,7 +7,11 @@ from mongoengine import Q
|
|||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.apimodels.base import UpdateResponse
|
from apiserver.apimodels.base import UpdateResponse
|
||||||
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
|
from apiserver.apimodels.users import (
|
||||||
|
CreateRequest,
|
||||||
|
SetPreferencesRequest,
|
||||||
|
UserRequest,
|
||||||
|
)
|
||||||
from apiserver.bll.project import ProjectBLL
|
from apiserver.bll.project import ProjectBLL
|
||||||
from apiserver.bll.user import UserBLL
|
from apiserver.bll.user import UserBLL
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
|
|||||||
return res.to_proper_dict()
|
return res.to_proper_dict()
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.get_by_id", required_fields=["user"])
|
@endpoint("users.get_by_id")
|
||||||
def get_by_id(call: APICall, company_id, _):
|
def get_by_id(call: APICall, company_id, request: UserRequest):
|
||||||
user_id = call.data["user"]
|
user_id = request.user
|
||||||
call.result.data = {"user": get_user(call, company_id, user_id)}
|
call.result.data = {"user": get_user(call, company_id, user_id)}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.get_all_ex", required_fields=[])
|
@endpoint("users.get_all_ex")
|
||||||
def get_all_ex(call: APICall, company_id, _):
|
def get_all_ex(call: APICall, company_id, _):
|
||||||
with translate_errors_context("retrieving users"):
|
with translate_errors_context("retrieving users"):
|
||||||
res = User.get_many_with_join(company=company_id, query_dict=call.data)
|
res = User.get_many_with_join(company=company_id, query_dict=call.data)
|
||||||
@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
|
|||||||
call.result.data = {"users": res}
|
call.result.data = {"users": res}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
|
@endpoint("users.get_all_ex", min_version="2.8")
|
||||||
def get_all_ex2_8(call: APICall, company_id, _):
|
def get_all_ex2_8(call: APICall, company_id, _):
|
||||||
with translate_errors_context("retrieving users"):
|
with translate_errors_context("retrieving users"):
|
||||||
data = call.data
|
data = call.data
|
||||||
@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
|
|||||||
call.result.data = {"users": res}
|
call.result.data = {"users": res}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.get_all", required_fields=[])
|
@endpoint("users.get_all")
|
||||||
def get_all(call: APICall, company_id, _):
|
def get_all(call: APICall, company_id, _):
|
||||||
with translate_errors_context("retrieving users"):
|
with translate_errors_context("retrieving users"):
|
||||||
res = User.get_many(
|
res = User.get_many(
|
||||||
@ -138,9 +142,9 @@ def create(call: APICall):
|
|||||||
UserBLL.create(call.data_model)
|
UserBLL.create(call.data_model)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.delete", required_fields=["user"])
|
@endpoint("users.delete")
|
||||||
def delete(call: APICall):
|
def delete(_: APICall, __, request: UserRequest):
|
||||||
UserBLL.delete(call.data["user"])
|
UserBLL.delete(request.user)
|
||||||
|
|
||||||
|
|
||||||
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||||
@ -159,9 +163,9 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
|||||||
return User.safe_update(company_id, user_id, partial_update_dict)
|
return User.safe_update(company_id, user_id, partial_update_dict)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
|
@endpoint("users.update", response_data_model=UpdateResponse)
|
||||||
def update(call, company_id, _):
|
def update(call, company_id, request: UserRequest):
|
||||||
user_id = call.data["user"]
|
user_id = request.user
|
||||||
update_count, updated_fields = update_user(user_id, company_id, call.data)
|
update_count, updated_fields = update_user(user_id, company_id, call.data)
|
||||||
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
|
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user