Rename migration script

Support refresh flag in debug image samples
Remove silent_dequeue_fail param to prevent status change in case task wasn't queued
Add organizations.get_user_companies
Fix reset should also reset active_duration
Add api_version to server.info
This commit is contained in:
allegroai 2021-01-05 18:09:34 +02:00
parent 618a0b9473
commit 3272d0f31f
12 changed files with 164 additions and 24 deletions

View File

@ -60,6 +60,7 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
class NextDebugImageSampleRequest(Base):

View File

@ -211,6 +211,7 @@ class DebugSampleHistory:
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
@ -225,15 +226,7 @@ class DebugSampleHistory:
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=task, metric=metric
)
state_.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
self._reset_variant_states(es_index, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
@ -241,6 +234,8 @@ class DebugSampleHistory:
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_variant_states(es_index, state=state_)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
@ -291,6 +286,17 @@ class DebugSampleHistory:
)
return res
def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
def _get_variant_iterations(
self,
es_index: str,

View File

@ -651,9 +651,8 @@ class TaskBLL:
company_id: str,
status_message: str,
status_reason: str,
silent_dequeue_fail=False,
):
cls.dequeue(task, company_id, silent_dequeue_fail)
cls.dequeue(task, company_id)
return ChangeStatusRequest(
task=task,

View File

@ -5,19 +5,25 @@ from pymongo.database import Database
def _add_active_duration(db: Database):
active_duration = "active_duration"
query = {active_duration: {"$eq": None}}
active_duration_key = "active_duration"
query = {"$or": [{active_duration_key: {"$eq": None}}, {active_duration_key: {"$eq": 0}}]}
collection = db["task"]
for doc in collection.find(
filter=query, projection=[active_duration, "status", "started", "completed"]
filter=query, projection=[active_duration_key, "status", "started", "completed"]
):
started = doc.get("started")
completed = doc.get("completed")
running = doc.get("status") == "running"
if started and doc.get(active_duration) is None:
active_duration_value = doc.get(active_duration_key)
if active_duration_value == 0:
collection.update_one(
{"_id": doc["_id"]},
{"$set": {active_duration: _get_active_duration(completed, running, started)}},
{"$set": {active_duration_key: None}},
)
elif started and active_duration_value is None:
collection.update_one(
{"_id": doc["_id"]},
{"$set": {active_duration_key: _get_active_duration(completed, running, started)}},
)

View File

@ -439,6 +439,10 @@
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
type: integer
}
refresh {
description: "If set then scroll state will be refreshed to reflect the latest changes in the debug images"
type: boolean
}
scroll_id {
type: string
description: "Scroll ID from the previous call to get_debug_image_sample or empty"

View File

@ -45,4 +45,61 @@ get_tags {
}
}
}
}
get_user_companies {
"2.12" {
description: "Get details for all companies associated with the current user"
request {
type: object
properties {}
additionalProperties: false
}
response {
type: object
properties {
companies {
description: "List of company information entries. First company is the user's own company"
type: array
items {
type: object
properties {
id {
description: "Company ID"
type: string
}
name {
description: "Company name"
type: string
}
allocated {
description: "Number of users allocated for company"
type: integer
}
owners {
description: "Company owners"
type: array
items {
type: object
properties {
id {
description: "User ID"
type: string
}
name {
description: "User Name"
type: string
}
avatar {
description: "User avatar (URL or base64-encoded data)"
type: string
}
}
}
}
}
}
}
}
}
}
}

View File

@ -77,6 +77,21 @@ info {
description: "Server UID"
type: string
}
api_version {
description: "Max API version supported"
type: string
}
}
}
}
"2.12": ${info."2.8"} {
response {
type: object
properties {
api_version {
description: "Max API version supported"
type: string
}
}
}
}

View File

@ -643,6 +643,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data = attr.asdict(res, recurse=False)

View File

@ -1,7 +1,9 @@
from collections import defaultdict
from operator import itemgetter
from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
@ -20,3 +22,26 @@ def get_tags(call: APICall, company, request: TagsRequest):
ret[field] |= vals
call.result.data = get_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]
call.result.data = {
"companies": [
{
"id": company_id,
"name": call.identity.company_name,
"allocated": len(users),
"owners": sorted(users, key=itemgetter("name")),
}
]
}

View File

@ -67,6 +67,12 @@ def info_2_8(call: APICall):
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
@endpoint("server.info", min_version="2.12")
def info_2_8(call: APICall):
info(call)
call.result.data["api_version"] = str(ServiceRepo.max_endpoint_version())
@endpoint(
"server.report_stats_option",
request_data_model=ReportStatsOptionRequest,

View File

@ -879,7 +879,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
force=force,
status_reason="reset",
status_message="reset",
).execute(started=None, completed=None, published=None, **updates)
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
)
# do not return artifacts since they are not serializable
@ -904,13 +904,16 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
only=("id", "execution", "status", "project", "system_tags"),
)
for task in tasks:
TaskBLL.dequeue_and_change_status(
task,
company_id,
request.status_message,
request.status_reason,
silent_dequeue_fail=True,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id,
request.status_message,
request.status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
task.update(
status_message=request.status_message,
status_reason=request.status_reason,

View File

@ -0,0 +1,17 @@
from apiserver.tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated)
users = company.owners
self.assertTrue(users)
self.assertTrue(u1.name < u2.name for u1, u2 in zip(users, users[1:]))
for user in company.owners:
self.assertTrue(user.id)
self.assertTrue(user.name)
self.assertIn("avatar", user)