mirror of
https://github.com/clearml/clearml-server
synced 2025-05-28 08:58:49 +00:00
Fix type annotations
Fix obtaining events for tasks moved from private to public Fix assert_exists() to return company_origin if requested
This commit is contained in:
parent
c67a56eb8d
commit
22e9c2b7eb
@ -115,7 +115,7 @@ class EventMetrics:
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company"),
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
@ -123,7 +123,7 @@ class EventMetrics:
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.company for t in task_objs}
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
|
@ -126,7 +126,9 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
q = q.only(*only)
|
||||
# Make sure to reset fields filters (some fields are excluded by default) since this
|
||||
# is an internal call and specific fields were requested.
|
||||
q = q.all_fields().only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
|
@ -773,7 +773,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
update = dict(set__company_origin=company_id, set__company="")
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
|
@ -217,3 +217,13 @@ class Task(AttributedDocument):
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
"""
|
||||
Returns the company ID used for locating indices containing task data.
|
||||
In case the task has a valid company, this is the company ID.
|
||||
Otherwise, if the task has a company_origin, this is a task that has been made public and the
|
||||
origin company should be used.
|
||||
Otherwise, an empty company is used.
|
||||
"""
|
||||
return self.company or self.company_origin or ""
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Text, Sequence, Callable, Union
|
||||
from typing import Text, Sequence, Callable, Union, Type
|
||||
|
||||
from funcsigs import signature
|
||||
from jsonmodels import models
|
||||
@ -18,8 +18,8 @@ def endpoint(
|
||||
name: Text,
|
||||
min_version: Text = "1.0",
|
||||
required_fields: Sequence[Text] = None,
|
||||
request_data_model: models.Base = None,
|
||||
response_data_model: models.Base = None,
|
||||
request_data_model: Type[models.Base] = None,
|
||||
response_data_model: Type[models.Base] = None,
|
||||
validate_schema=False,
|
||||
):
|
||||
""" Endpoint decorator, used to declare a method as an endpoint handler """
|
||||
|
@ -50,13 +50,13 @@ def add_batch(call: APICall, company_id, _):
|
||||
def get_task_log_v1_5(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
order = call.data.get("order") or "desc"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
order,
|
||||
event_type="log",
|
||||
@ -72,7 +72,7 @@ def get_task_log_v1_5(call, company_id, _):
|
||||
def get_task_log_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
order = call.data.get("order") or "desc"
|
||||
@ -83,7 +83,7 @@ def get_task_log_v1_7(call, company_id, _):
|
||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
company_id=task.company,
|
||||
company_id=task.get_index_company(),
|
||||
task_id=task_id,
|
||||
order=scroll_order,
|
||||
event_type="log",
|
||||
@ -103,11 +103,11 @@ def get_task_log_v1_7(call, company_id, _):
|
||||
def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
res = event_bll.log_events_iterator.get_task_events(
|
||||
company_id=task.company,
|
||||
company_id=task.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=request.batch_size,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -131,7 +131,7 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
def download_task_log(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
line_type = call.data.get("line_type", "json").lower()
|
||||
@ -175,7 +175,7 @@ def download_task_log(call, company_id, _):
|
||||
batch_size = 1000
|
||||
while True:
|
||||
log_events, scroll_id, _ = event_bll.scroll_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
order="asc",
|
||||
event_type="log",
|
||||
@ -211,11 +211,11 @@ def download_task_log(call, company_id, _):
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
task.company, task_id, "training_stats_vector"
|
||||
task.get_index_company(), task_id, "training_stats_vector"
|
||||
)
|
||||
)
|
||||
|
||||
@ -224,11 +224,11 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
task.company, task_id, "training_stats_scalar"
|
||||
task.get_index_company(), task_id, "training_stats_scalar"
|
||||
)
|
||||
)
|
||||
|
||||
@ -241,12 +241,12 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
task.company, task_id, metric, variant
|
||||
task.get_index_company(), task_id, metric, variant
|
||||
)
|
||||
call.result.data = dict(
|
||||
metric=metric, variant=variant, vectors=vectors, iterations=iterations
|
||||
@ -262,10 +262,10 @@ def get_task_events(call, company_id, _):
|
||||
order = call.data.get("order") or "asc"
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
sort=[{"timestamp": {"order": order}}],
|
||||
event_type=event_type,
|
||||
@ -288,10 +288,10 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
event_type="training_stats_scalar",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -311,12 +311,13 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
def get_task_latest_scalar_values(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
index_company = task.get_index_company()
|
||||
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
|
||||
task.company, task_id
|
||||
index_company, task_id
|
||||
)
|
||||
es_index = EventMetrics.get_index_name(task.company, "*")
|
||||
es_index = EventMetrics.get_index_name(index_company, "*")
|
||||
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
@ -336,10 +337,10 @@ def scalar_metrics_iter_histogram(
|
||||
call, company_id, request: ScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, request.task, allow_public=True, only=("company",)
|
||||
company_id, request.task, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
|
||||
task.company, task_id=request.task, samples=request.samples, key=request.key
|
||||
task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
|
||||
)
|
||||
call.result.data = metrics
|
||||
|
||||
@ -374,12 +375,12 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id,
|
||||
only=("id", "name", "company"),
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
@ -417,12 +418,12 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name", "company"),
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
@ -458,7 +459,7 @@ def get_task_plots_v1_7(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company, task_id,
|
||||
@ -469,7 +470,7 @@ def get_task_plots_v1_7(call, company_id, _):
|
||||
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -494,10 +495,10 @@ def get_task_plots(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
result = event_bll.get_task_plots(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
tasks=[task_id],
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
@ -521,7 +522,7 @@ def get_debug_images_v1_7(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company, task_id,
|
||||
@ -532,7 +533,7 @@ def get_debug_images_v1_7(call, company_id, _):
|
||||
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -558,10 +559,10 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.company,
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -589,10 +590,10 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_ids = {m.task for m in request.metrics}
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id, task_ids=task_ids, allow_public=True, only=("company",)
|
||||
company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
@ -626,10 +627,10 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=request.tasks, allow_public=True, only=("company",)
|
||||
company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
res = event_bll.metrics.get_tasks_metrics(
|
||||
task.company, task_ids=request.tasks, event_type=request.event_type
|
||||
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
|
Loading…
Reference in New Issue
Block a user