mirror of
https://github.com/clearml/clearml-server
synced 2025-06-23 08:45:30 +00:00
Add support for debug images history using events.get_debug_image_event and events.get_debug_image_iterations
Remove untracked files
This commit is contained in:
parent f084f6b9e7
commit 323b5db07c
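
For context, both new calls are plain JSON-over-HTTP apiserver endpoints. A minimal client sketch follows; the base URL and port, the task id, and the metric/variant names are illustrative assumptions, not part of this commit, and auth headers are omitted for brevity:

import requests

base = "http://localhost:8008"  # assumed local apiserver address

# Latest debug image reported at or before iteration 5 (ids are placeholders)
event_res = requests.post(
    f"{base}/events.get_debug_image_event",
    json={"task": "<task-id>", "metric": "Metric1", "variant": "Variant1", "iteration": 5},
).json()

# Min/max iterations for which valid image urls exist for the same metric/variant
bounds_res = requests.post(
    f"{base}/events.get_debug_image_iterations",
    json={"task": "<task-id>", "metric": "Metric1", "variant": "Variant1"},
).json()
# payloads are assumed to arrive under the server's usual "data" envelope
print(bounds_res["data"]["min_iteration"], bounds_res["data"]["max_iteration"])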
@@ -51,6 +51,20 @@ class DebugImagesRequest(Base):
     scroll_id: str = StringField()
+
+
+class TaskMetricVariant(Base):
+    task: str = StringField(required=True)
+    metric: str = StringField(required=True)
+    variant: str = StringField(required=True)
+
+
+class DebugImageIterationsRequest(TaskMetricVariant):
+    pass
+
+
+class DebugImageEventRequest(TaskMetricVariant):
+    iteration: Optional[int] = IntField()
 
 
 class LogOrderEnum(StringEnum):
     asc = auto()
     desc = auto()
@@ -1,155 +0,0 @@
-from flask import Flask, request, Response
-
-from werkzeug.exceptions import BadRequest
-
-from apiserver.apierrors.base import BaseError
-from apiserver.config import config
-from apiserver.service_repo import ServiceRepo, APICall
-from apiserver.service_repo.auth import AuthType
-from apiserver.service_repo.errors import PathParsingError
-from apiserver.timing_context import TimingContext
-from apiserver.utilities import json
-
-log = config.logger(__file__)
-
-
-def before_app_first_request():
-    pass
-
-
-def before_request():
-    if request.method == "OPTIONS":
-        return "", 200
-    if "/static/" in request.path:
-        return
-    if request.path.startswith("/rq"):
-        return
-
-    try:
-        call = create_api_call(request)
-        content, content_type = ServiceRepo.handle_call(call)
-        headers = {}
-        if call.result.filename:
-            headers[
-                "Content-Disposition"
-            ] = f"attachment; filename={call.result.filename}"
-
-        if call.result.headers:
-            headers.update(call.result.headers)
-
-        response = Response(
-            content, mimetype=content_type, status=call.result.code, headers=headers
-        )
-
-        if call.result.cookies:
-            for key, value in call.result.cookies.items():
-                if value is None:
-                    response.set_cookie(
-                        key, "", expires=0, **config.get("apiserver.auth.cookies")
-                    )
-                else:
-                    response.set_cookie(
-                        key, value, **config.get("apiserver.auth.cookies")
-                    )
-
-        return response
-    except Exception as ex:
-        log.exception(f"Failed processing request {request.url}: {ex}")
-        return f"Failed processing request {request.url}", 500
-
-
-def update_call_data(call, req):
-    """ Use request payload/form to fill call data or batched data """
-    if req.content_type == "application/json-lines":
-        items = []
-        for i, line in enumerate(req.data.splitlines()):
-            try:
-                event = json.loads(line)
-                if not isinstance(event, dict):
-                    raise BadRequest(
-                        f"json lines must contain objects, found: {type(event).__name__}"
-                    )
-                items.append(event)
-            except ValueError as e:
-                msg = f"{e} in batch item #{i}"
-                req.on_json_loading_failed(msg)
-        call.batched_data = items
-    else:
-        json_body = req.get_json(force=True, silent=False) if req.data else None
-        # merge form and args
-        form = req.form.copy()
-        form.update(req.args)
-        form = form.to_dict()
-        # convert string numbers to floats
-        for key in form:
-            if form[key].replace(".", "", 1).isdigit():
-                if "." in form[key]:
-                    form[key] = float(form[key])
-                else:
-                    form[key] = int(form[key])
-            elif form[key].lower() == "true":
-                form[key] = True
-            elif form[key].lower() == "false":
-                form[key] = False
-        call.data = json_body or form or {}
-
-
-def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
-    call = call or APICall(
-        "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
-    )
-    call.set_error_result(msg=msg, code=code, subcode=subcode)
-    return call
-
-
-def create_api_call(req):
-    call = None
-    try:
-        # Parse the request path
-        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
-
-        # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
-        # in any case, request headers always take precedence.
-        auth_cookie = req.cookies.get(
-            config.get("apiserver.auth.session_auth_cookie_name")
-        )
-        headers = (
-            {}
-            if not auth_cookie
-            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
-        )
-        headers.update(
-            list(req.headers.items())
-        )  # add (possibly override with) the headers
-
-        # Construct call instance
-        call = APICall(
-            endpoint_name=endpoint_name,
-            remote_addr=req.remote_addr,
-            endpoint_version=endpoint_version,
-            headers=headers,
-            files=req.files,
-            host=req.host,
-        )
-
-        # Update call data from request
-        with TimingContext("preprocess", "update_call_data"):
-            update_call_data(call, req)
-
-    except PathParsingError as ex:
-        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
-        call.log_api = False
-    except BadRequest as ex:
-        call = _call_or_empty_with_error(call, req, ex.description, 400)
-    except BaseError as ex:
-        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
-    except Exception as ex:
-        log.exception("Error creating call")
-        call = _call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
-
-    return call
-
-
-def register_routes(app: Flask):
-    app.before_first_request(before_app_first_request)
-    app.before_request(before_request)
@@ -22,6 +22,7 @@ from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.task.metrics import MetricEventStats
 from apiserver.database.model.task.task import Task
 from apiserver.timing_context import TimingContext
+from apiserver.utilities.dicts import nested_get
 
 
 class VariantScrollState(Base):
@@ -465,3 +466,103 @@ class DebugImagesIterator:
         if events_to_remove:
             it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
         return [it for it in iterations if it["events"]]
+
+    def get_debug_image_event(
+        self,
+        company_id: str,
+        task: str,
+        metric: str,
+        variant: str,
+        iteration: Optional[int] = None,
+    ) -> Optional[dict]:
+        """
+        Get the debug image for the requested iteration or the latest one before it.
+        If the iteration is not passed then get the latest event.
+        """
+        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        if not self.es.indices.exists(es_index):
+            return None
+
+        must_conditions = [
+            {"term": {"task": task}},
+            {"term": {"metric": metric}},
+            {"term": {"variant": variant}},
+            {"exists": {"field": "url"}},
+        ]
+        if iteration is not None:
+            must_conditions.append({"range": {"iter": {"lte": iteration}}})
+
+        es_req = {
+            "size": 1,
+            "sort": {"iter": "desc"},
+            "query": {"bool": {"must": must_conditions}},
+        }
+
+        with translate_errors_context(), TimingContext("es", "get_debug_image_event"):
+            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+
+        hits = nested_get(es_res, ("hits", "hits"))
+        if not hits:
+            return None
+
+        return hits[0]["_source"]
+
+    def get_debug_image_iterations(
+        self, company_id: str, task: str, metric: str, variant: str
+    ) -> Tuple[int, int]:
+        """
+        Return the valid min and max iterations for which the task reported images.
+        The min iteration is the lowest iteration that still contains a non-recycled image url.
+        """
+        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        if not self.es.indices.exists(es_index):
+            return 0, 0
+
+        es_req: dict = {
+            "size": 1,
+            "sort": {"iter": "desc"},
+            "query": {
+                "bool": {
+                    "must": [
+                        {"term": {"task": task}},
+                        {"term": {"metric": metric}},
+                        {"term": {"variant": variant}},
+                        {"exists": {"field": "url"}},
+                    ]
+                }
+            },
+            "aggs": {
+                "url_min": {
+                    "terms": {
+                        "field": "url",
+                        "order": {"max_iter": "asc"},
+                        "size": 1,  # we need only one url from the least recent iteration
+                    },
+                    "aggs": {
+                        "max_iter": {"max": {"field": "iter"}},
+                        "iters": {
+                            "top_hits": {
+                                "sort": {"iter": {"order": "desc"}},
+                                "size": 1,
+                                "_source": "iter",
+                            }
+                        },
+                    },
+                }
+            },
+        }
+
+        with translate_errors_context(), TimingContext("es", "get_debug_image_iterations"):
+            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+
+        hits = nested_get(es_res, ("hits", "hits"))
+        if not hits:
+            return 0, 0
+
+        max_iter = hits[0]["_source"]["iter"]
+        url_min_buckets = nested_get(es_res, ("aggregations", "url_min", "buckets"))
+        if not url_min_buckets:
+            return 0, max_iter
+
+        min_iter = url_min_buckets[0]["max_iter"]["value"]
+        return int(min_iter), max_iter
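
The url_min terms aggregation above buckets events by image url and orders the buckets by each url's latest iteration, ascending, so the first bucket belongs to the url that was overwritten least recently; its max_iter is therefore the earliest iteration whose image is still valid. A small self-contained sketch of the same arithmetic, assuming the cyclic url-reuse pattern exercised by the tests below:

# Iteration n writes url n % unique_images, so an image is "recycled"
# once a later iteration reuses its url (same pattern as in the tests).
iterations, unique_images = 10, 4
last_iter_per_url = {}
for n in range(iterations):
    last_iter_per_url[f"url_{n % unique_images}"] = n

# The bucket with the smallest "latest iteration" marks the earliest
# still-valid image; the overall latest iteration is the max bound.
min_iteration = min(last_iter_per_url.values())  # 6 == iterations - unique_images
max_iteration = max(last_iter_per_url.values())  # 9 == iterations - 1
assert (min_iteration, max_iteration) == (iterations - unique_images, iterations - 1)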
@@ -395,6 +395,78 @@
             }
         }
     }
+    get_debug_image_event {
+        "2.11": {
+            description: "Return the last debug image per metric and variant for the provided iteration"
+            request {
+                type: object
+                required: [task, metric, variant]
+                properties {
+                    task {
+                        description: "Task ID"
+                        type: string
+                    }
+                    metric {
+                        description: "Metric name"
+                        type: string
+                    }
+                    variant {
+                        description: "Metric variant"
+                        type: string
+                    }
+                    iteration {
+                        description: "The latest iteration to bring debug image from. If not specified then the latest reported image is retrieved"
+                        type: integer
+                    }
+                }
+            }
+            response {
+                type: object
+                properties {
+                    event {
+                        description: "The latest debug image for the specified iteration"
+                        type: object
+                    }
+                }
+            }
+        }
+    }
+    get_debug_image_iterations {
+        "2.11": {
+            description: "Return the min and max iterations for which valid urls are present"
+            request {
+                type: object
+                required: [task, metric, variant]
+                properties {
+                    task {
+                        description: "Task ID"
+                        type: string
+                    }
+                    metric {
+                        description: "Metric name"
+                        type: string
+                    }
+                    variant {
+                        description: "Metric variant"
+                        type: string
+                    }
+                }
+            }
+            response {
+                type: object
+                properties {
+                    min_iteration {
+                        description: "Minimal iteration for which a non-recycled image exists"
+                        type: integer
+                    }
+                    max_iteration {
+                        description: "Maximal iteration for which an image was reported"
+                        type: integer
+                    }
+                }
+            }
+        }
+    }
     get_task_metrics{
         "2.7": {
             description: "For each task, get a list of metrics for which the requested event type was reported"
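
Read together, the schemas above describe payloads shaped like the following sketch; the values are illustrative (they mirror the cyclic-url scenario used in the tests below), and the exact event fields depend on what was stored:

# Hypothetical payloads conforming to the schemas above (values illustrative)
get_debug_image_event_response = {
    "event": {  # the stored training_debug_image event, as returned from ES
        "task": "<task-id>",
        "metric": "Metric1",
        "variant": "Variant1",
        "iter": 9,
        "url": "Metric1_Variant1_1",
    }
}
get_debug_image_iterations_response = {"min_iteration": 6, "max_iteration": 9}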
@@ -13,6 +13,8 @@ from apiserver.apimodels.events import (
     TaskMetricsRequest,
     LogEventsRequest,
     LogOrderEnum,
+    DebugImageIterationsRequest,
+    DebugImageEventRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.event.event_metrics import EventMetrics
@@ -624,6 +626,49 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
     )
 
 
+@endpoint(
+    "events.get_debug_image_event",
+    min_version="2.11",
+    request_data_model=DebugImageEventRequest,
+)
+def get_debug_image(call, company_id, request: DebugImageEventRequest):
+    task = task_bll.assert_exists(
+        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    )[0]
+    call.result.data = {
+        "event": event_bll.debug_images_iterator.get_debug_image_event(
+            company_id=task.company,
+            task=request.task,
+            metric=request.metric,
+            variant=request.variant,
+            iteration=request.iteration,
+        )
+    }
+
+
+@endpoint(
+    "events.get_debug_image_iterations",
+    min_version="2.11",
+    request_data_model=DebugImageIterationsRequest,
+)
+def get_debug_image_iterations(call, company_id, request: DebugImageIterationsRequest):
+    task = task_bll.assert_exists(
+        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    )[0]
+
+    min_iter, max_iter = event_bll.debug_images_iterator.get_debug_image_iterations(
+        company_id=task.company,
+        task=request.task,
+        metric=request.metric,
+        variant=request.variant,
+    )
+
+    call.result.data = {
+        "max_iteration": max_iter,
+        "min_iteration": min_iter,
+    }
+
+
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
apiserver/tests/automated/test_task_debug_images.py (new file, 173 lines)
@@ -0,0 +1,173 @@
+from functools import partial
+from typing import Sequence
+
+from apiserver.es_factory import es_factory
+from apiserver.tests.automated import TestService
+
+
+class TestTaskDebugImages(TestService):
+    def setUp(self, version="2.11"):
+        super().setUp(version=version)
+
+    def _temp_task(self, name="test task events"):
+        task_input = dict(
+            name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
+        )
+        return self.create_temp("tasks", **task_input)
+
+    @staticmethod
+    def _create_task_event(task, iteration, **kwargs):
+        return {
+            "worker": "test",
+            "type": "training_debug_image",
+            "task": task,
+            "iter": iteration,
+            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
+            **kwargs,
+        }
+
+    def test_get_debug_image(self):
+        task = self._temp_task()
+        metric = "Metric1"
+        variant = "Variant1"
+
+        # test empty
+        res = self.api.events.get_debug_image_iterations(
+            task=task, metric=metric, variant=variant
+        )
+        self.assertEqual(res.min_iteration, 0)
+        self.assertEqual(res.max_iteration, 0)
+        res = self.api.events.get_debug_image_event(
+            task=task, metric=metric, variant=variant
+        )
+        self.assertEqual(res.event, None)
+
+        # test existing events
+        iterations = 10
+        unique_images = 4
+        events = [
+            self._create_task_event(
+                task=task,
+                iteration=n,
+                metric=metric,
+                variant=variant,
+                url=f"{metric}_{variant}_{n % unique_images}",
+            )
+            for n in range(iterations)
+        ]
+        self.send_batch(events)
+        res = self.api.events.get_debug_image_iterations(
+            task=task, metric=metric, variant=variant
+        )
+        self.assertEqual(res.max_iteration, iterations - 1)
+        self.assertEqual(res.min_iteration, max(0, iterations - unique_images))
+
+        # if iteration is not specified then return the event from the last one
+        res = self.api.events.get_debug_image_event(
+            task=task, metric=metric, variant=variant
+        )
+        self._assertEqualEvent(res.event, events[-1])
+
+        # otherwise from the specified iteration
+        iteration = 8
+        res = self.api.events.get_debug_image_event(
+            task=task, metric=metric, variant=variant, iteration=iteration
+        )
+        self._assertEqualEvent(res.event, events[iteration])
+
+    def _assertEqualEvent(self, ev1: dict, ev2: dict):
+        self.assertEqual(ev1["iter"], ev2["iter"])
+        self.assertEqual(ev1["url"], ev2["url"])
+
+    def test_task_debug_images(self):
+        task = self._temp_task()
+        metric = "Metric1"
+        variants = [("Variant1", 7), ("Variant2", 4)]
+        iterations = 10
+
+        # test empty
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}], iters=5,
+        )
+        self.assertFalse(res.metrics)
+
+        # create events
+        events = [
+            self._create_task_event(
+                task=task,
+                iteration=n,
+                metric=metric,
+                variant=variant,
+                url=f"{metric}_{variant}_{n % unique_images}",
+            )
+            for n in range(iterations)
+            for (variant, unique_images) in variants
+        ]
+        self.send_batch(events)
+
+        # init testing
+        unique_images = [unique for (_, unique) in variants]
+        scroll_id = None
+        assert_debug_images = partial(
+            self._assertDebugImages,
+            task=task,
+            metric=metric,
+            max_iter=iterations - 1,
+            unique_images=unique_images,
+        )
+
+        # test forward navigation
+        for page in range(3):
+            scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
+
+        # test backwards navigation
+        scroll_id = assert_debug_images(
+            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
+        )
+
+        # beyond the latest iteration and back
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}],
+            iters=5,
+            scroll_id=scroll_id,
+            navigate_earlier=False,
+        )
+        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
+        assert_debug_images(scroll_id=scroll_id, expected_page=1)
+
+        # refresh
+        assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
+
+    def _assertDebugImages(
+        self,
+        task,
+        metric,
+        max_iter: int,
+        unique_images: Sequence[int],
+        scroll_id,
+        expected_page: int,
+        iters: int = 5,
+        **extra_params,
+    ):
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}],
+            iters=iters,
+            scroll_id=scroll_id,
+            **extra_params,
+        )
+        data = res["metrics"][0]
+        self.assertEqual(data["task"], task)
+        self.assertEqual(data["metric"], metric)
+        left_iterations = max(0, max(unique_images) - expected_page * iters)
+        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
+        for it in data["iterations"]:
+            events_per_iter = sum(
+                1 for unique in unique_images if unique > max_iter - it["iter"]
+            )
+            self.assertEqual(len(it["events"]), events_per_iter)
+        return res.scroll_id
+
+    def send_batch(self, events):
+        _, data = self.api.send_batch("events.add_batch", events)
+        return data
@@ -1,10 +1,6 @@
-"""
-Comprehensive test of all(?) use cases of datasets and frames
-"""
 import json
 import operator
 import unittest
-from functools import partial
 from statistics import mean
 from typing import Sequence, Optional, Tuple
 
@@ -99,95 +95,6 @@ class TestTaskEvents(TestService):
         self.assertEqual(iter_count - 1, metric_data.max_value)
         self.assertEqual(0, metric_data.min_value)
-
-    def test_task_debug_images(self):
-        task = self._temp_task()
-        metric = "Metric1"
-        variants = [("Variant1", 7), ("Variant2", 4)]
-        iterations = 10
-
-        # test empty
-        res = self.api.events.debug_images(
-            metrics=[{"task": task, "metric": metric}], iters=5,
-        )
-        self.assertFalse(res.metrics)
-
-        # create events
-        events = [
-            self._create_task_event(
-                "training_debug_image",
-                task=task,
-                iteration=n,
-                metric=metric,
-                variant=variant,
-                url=f"{metric}_{variant}_{n % unique_images}",
-            )
-            for n in range(iterations)
-            for (variant, unique_images) in variants
-        ]
-        self.send_batch(events)
-
-        # init testing
-        unique_images = [unique for (_, unique) in variants]
-        scroll_id = None
-        assert_debug_images = partial(
-            self._assertDebugImages,
-            task=task,
-            metric=metric,
-            max_iter=iterations - 1,
-            unique_images=unique_images,
-        )
-
-        # test forward navigation
-        for page in range(3):
-            scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
-
-        # test backwards navigation
-        scroll_id = assert_debug_images(
-            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
-        )
-
-        # beyond the latest iteration and back
-        res = self.api.events.debug_images(
-            metrics=[{"task": task, "metric": metric}],
-            iters=5,
-            scroll_id=scroll_id,
-            navigate_earlier=False,
-        )
-        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
-        assert_debug_images(scroll_id=scroll_id, expected_page=1)
-
-        # refresh
-        assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
-
-    def _assertDebugImages(
-        self,
-        task,
-        metric,
-        max_iter: int,
-        unique_images: Sequence[int],
-        scroll_id,
-        expected_page: int,
-        iters: int = 5,
-        **extra_params,
-    ):
-        res = self.api.events.debug_images(
-            metrics=[{"task": task, "metric": metric}],
-            iters=iters,
-            scroll_id=scroll_id,
-            **extra_params,
-        )
-        data = res["metrics"][0]
-        self.assertEqual(data["task"], task)
-        self.assertEqual(data["metric"], metric)
-        left_iterations = max(0, max(unique_images) - expected_page * iters)
-        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
-        for it in data["iterations"]:
-            events_per_iter = sum(
-                1 for unique in unique_images if unique > max_iter - it["iter"]
-            )
-            self.assertEqual(len(it["events"]), events_per_iter)
-        return res.scroll_id
-
 
     def test_error_events(self):
         task = self._temp_task()
         events = [
@@ -1,188 +0,0 @@
-from __future__ import unicode_literals
-
-from collections import OrderedDict
-
-import attr
-import related
-import six
-from related import StringField, FloatField, IntegerField, BooleanField
-
-
-@attr.s
-class AtomicType(object):
-    python = attr.ib()
-    related = attr.ib()
-
-
-ATOMIC_TYPES = {
-    'integer': AtomicType(int, IntegerField),
-    'number': AtomicType(float, FloatField),
-    'string': AtomicType(str, StringField),
-    'boolean': AtomicType(bool, BooleanField),
-}
-
-
-def resolve_ref(definitions, value):
-    ref = value.get('$ref')
-    if ref:
-        name = ref.split('/')[-1]
-        return definitions[name]
-    one_of = value.get('oneOf')
-    if one_of:
-        one_of = list(one_of)
-        one_of.remove(dict(type='null'))
-        assert len(one_of) == 1
-        return dict(value, **resolve_ref(definitions, one_of.pop()))
-    return value
-
-
-def normalize_type(typ):
-    if isinstance(typ, six.string_types):
-        return typ
-    return (set(typ) - {'null'}).pop()
-
-
-@attr.s
-class RelatedBuilder(object):
-    """
-    Converts jsonschema to related class or field.
-    :param name: Object name. Will be used as the name of the class and to detect recursive objects,
-        which are not supported.
-    :param schema: jsonschema which is the base of the object
-    :param required: In case of child fields, whether the field is required. Only used in ``to_field``.
-    :param definitions: Dictionary to resolve definitions. Defaults to schema['definitions'].
-    :param default: Default value of field. Only used in ``to_field``.
-    """
-
-    name = attr.ib(type=six.text_type)
-    schema = attr.ib(type=dict, repr=False)
-    required = attr.ib(type=bool, default=False)
-    definitions = attr.ib(type=dict, default=None, repr=False)
-    default = attr.ib(default=attr.Factory(lambda: attr.NOTHING))
-
-    def __attrs_post_init__(self):
-        self.schema = resolve_ref(self.definitions, self.schema)
-        self.type = normalize_type(self.schema['type'])
-        self.definitions = self.definitions or self.schema.get('definitions')
-
-    def to_field(self):
-        """
-        Creates the appropriate ``related`` field from instance.
-        NOTE: Items and nesting level of nested arrays will not be checked.
-        """
-        if self.type in ATOMIC_TYPES:
-            field = ATOMIC_TYPES[self.type].related
-            return field(default=self.default, required=self.required)
-        if self.type == 'array':
-            sub_schema = self.schema['items']
-            builder = RelatedBuilder(
-                '{}_items'.format(self.name), sub_schema, definitions=self.definitions
-            )
-            return related.SequenceField(
-                list if builder.type == 'array' else builder.to_class(),
-                default=attr.Factory(list),
-            )
-        if self.schema.get('additionalProperties') is True:
-            return attr.ib(type=dict, default=None)
-
-        return related.ChildField(
-            self.to_class(), default=self.default, required=self.required
-        )
-
-    def to_class(self):
-        """
-        Creates a related class.
-        """
-        required = self.schema.get('required', [])
-
-        if self.type in ATOMIC_TYPES:
-            return ATOMIC_TYPES[self.type].python
-
-        if self.type == 'array':
-            raise RuntimeError(self, 'Cannot convert array to related class')
-
-        assert self.type and normalize_type(self.type) == 'object', (
-            self.type,
-            list(self.schema),
-        )
-        properties = sorted(
-            tuple(
-                (
-                    inner_name,
-                    RelatedBuilder(
-                        name=inner_name,
-                        schema=value,
-                        required=inner_name in required,
-                        definitions=definitions,
-                    ),
-                )
-                for inner_name, value in self.schema['properties'].items()
-                if inner_name != self.name
-            ),
-            key=lambda pair: pair[1].required,
-            reverse=True,
-        )
-        return related.mutable(
-            type(
-                self.name,
-                (object,),
-                OrderedDict([(key, builder.to_field()) for key, builder in properties]),
-            )
-        )
-
-
-class Visitor(object):
-    """Base class for visitors."""
-
-    def visit(self, node, *args, **kwargs):
-        """Visit a node.
-
-        Calls ``visit_CLASSNAME`` on itself passing ``node``, where
-        ``CLASSNAME`` is the node's class. If the visitor does not implement an
-        appropriate visitation method, will go up the
-        `MRO <https://www.python.org/download/releases/2.3/mro/>`_ until a
-        match is found.
-
-        If the search exhausts all classes of node, raises a
-        :class:`~exceptions.NotImplementedError`.
-
-        :param node: The node to visit.
-        :return: The return value of the called visitation function.
-        """
-        if isinstance(node, type):
-            mro = node.mro()
-        else:
-            mro = type(node).mro()
-        for cls in mro:
-            meth = getattr(self, 'visit_' + cls.__name__, None)
-            if meth is None:
-                continue
-            return meth(node, *args, **kwargs)
-
-        raise NotImplementedError(
-            'No visitation method visit_{}'.format(node.__class__.__name__)
-        )
-
-
-class SchemaCleaner(Visitor):
-    def __init__(self, definitions):
-        self.definitions = definitions
-
-    def visit_dict(self, obj, schema):
-        schema = resolve_ref(self.definitions, schema)
-        if schema.get('additionalProperties') is True:
-            return
-        props = schema['properties']
-        for key, value in list(obj.items()):
-            if key in props:
-                self.visit(value, props[key])
-            else:
-                del obj[key]
-
-    def visit_list(self, obj, schema):
-        for value in obj:
-            self.visit(value, schema['items'])
-
-    @staticmethod
-    def visit_object(obj, schema):
-        pass